diff --git a/.gitignore b/.gitignore index cc526858..04e894c6 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,9 @@ target/ # vim *.swp *.swo + +# PyCharm +.idea/ + +# env files +.env diff --git a/README.rst b/README.rst index be6bd3f0..507504a9 100644 --- a/README.rst +++ b/README.rst @@ -12,7 +12,7 @@ smart_open — utils for streaming large files in Python What? ===== -``smart_open`` is a Python 2 & Python 3 library for **efficient streaming of very large files** from/to storages such as S3, HDFS, WebHDFS, HTTP, HTTPS, SFTP, or local filesystem. It supports transparent, on-the-fly (de-)compression for a variety of different formats. +``smart_open`` is a Python 2 & Python 3 library for **efficient streaming of very large files** from/to storages such as S3, GCS, HDFS, WebHDFS, HTTP, HTTPS, SFTP, or local filesystem. It supports transparent, on-the-fly (de-)compression for a variety of different formats. ``smart_open`` is a drop-in replacement for Python's built-in ``open()``: it can do anything ``open`` can (100% compatible, falls back to native ``open`` wherever possible), plus lots of nifty extra stuff on top. @@ -80,6 +80,7 @@ Other examples of URLs that ``smart_open`` accepts:: s3://my_bucket/my_key s3://my_key:my_secret@my_bucket/my_key s3://my_key:my_secret@my_server:my_port@my_bucket/my_key + gs://my_bucket/my_blob hdfs:///path/file hdfs://path/file webhdfs://host:port/path/file @@ -174,6 +175,14 @@ More examples with open('s3://bucket/key.txt', 'wb', transport_params=transport_params) as fout: fout.write(b'here we stand') + # stream from GCS + for line in open('gs://my_bucket/my_file.txt'): + print(line) + + # stream content *into* GCS (write mode): + with open('gs://my_bucket/my_file.txt', 'wb') as fout: + fout.write(b'hello world') + Supported Compression Formats ----------------------------- @@ -212,6 +221,7 @@ Transport-specific Options - HTTP, HTTPS (read-only) - SSH, SCP and SFTP - WebHDFS +- GCS Each option involves setting up its own set of parameters. For example, for accessing S3, you often need to set up authentication, like API keys or a profile name. 
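For GCS, the transport parameters correspond to the keyword arguments of the new ``smart_open.gcs.open`` function added in this patch (``buffer_size``, ``min_part_size`` and ``client``), because ``_open_binary_stream`` filters ``transport_params`` through ``_check_kwargs`` before forwarding them. A minimal sketch of supplying your own storage client; the service-account key path below is a placeholder::

    import google.cloud.storage
    import smart_open

    # hypothetical credentials file; any google.cloud.storage.Client works here
    client = google.cloud.storage.Client.from_service_account_json('/path/to/key.json')

    transport_params = dict(client=client, min_part_size=256 * 1024)
    with smart_open.open('gs://my_bucket/my_blob.txt', 'wb', transport_params=transport_params) as fout:
        fout.write(b'hello world')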
diff --git a/integration-tests/test_gcs.py b/integration-tests/test_gcs.py new file mode 100644 index 00000000..ebf1f0e3 --- /dev/null +++ b/integration-tests/test_gcs.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +import io +import os + +import google.cloud.storage +from six.moves.urllib import parse as urlparse + +import smart_open + +_GCS_URL = os.environ.get('SO_GCS_URL') +assert _GCS_URL is not None, 'please set the SO_GCS_URL environment variable' + + +def initialize_bucket(): + client = google.cloud.storage.Client() + parsed = urlparse.urlparse(_GCS_URL) + bucket_name = parsed.netloc + prefix = parsed.path + bucket = client.get_bucket(bucket_name) + blobs = bucket.list_blobs(prefix=prefix) + for blob in blobs: + blob.delete() + + +def write_read(key, content, write_mode, read_mode, **kwargs): + with smart_open.open(key, write_mode, **kwargs) as fout: + fout.write(content) + with smart_open.open(key, read_mode, **kwargs) as fin: + return fin.read() + + +def read_length_prefixed_messages(key, read_mode, **kwargs): + result = io.BytesIO() + + with smart_open.open(key, read_mode, **kwargs) as fin: + length_byte = fin.read(1) + while len(length_byte): + result.write(length_byte) + msg = fin.read(ord(length_byte)) + result.write(msg) + length_byte = fin.read(1) + return result.getvalue() + + +def test_gcs_readwrite_text(benchmark): + initialize_bucket() + + key = _GCS_URL + '/sanity.txt' + text = 'с гранатою в кармане, с чекою в руке' + actual = benchmark(write_read, key, text, 'w', 'r', encoding='utf-8') + assert actual == text + + +def test_gcs_readwrite_text_gzip(benchmark): + initialize_bucket() + + key = _GCS_URL + '/sanity.txt.gz' + text = 'не чайки здесь запели на знакомом языке' + actual = benchmark(write_read, key, text, 'w', 'r', encoding='utf-8') + assert actual == text + + +def test_gcs_readwrite_binary(benchmark): + initialize_bucket() + + key = _GCS_URL + '/sanity.txt' + binary = b'this is a test' + actual = benchmark(write_read, key, binary, 'wb', 'rb') + assert actual == binary + + +def test_gcs_readwrite_binary_gzip(benchmark): + initialize_bucket() + + key = _GCS_URL + '/sanity.txt.gz' + binary = b'this is a test' + actual = benchmark(write_read, key, binary, 'wb', 'rb') + assert actual == binary + + +def test_gcs_performance(benchmark): + initialize_bucket() + + one_megabyte = io.BytesIO() + for _ in range(1024*128): + one_megabyte.write(b'01234567') + one_megabyte = one_megabyte.getvalue() + + key = _GCS_URL + '/performance.txt' + actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb') + assert actual == one_megabyte + + +def test_gcs_performance_gz(benchmark): + initialize_bucket() + + one_megabyte = io.BytesIO() + for _ in range(1024*128): + one_megabyte.write(b'01234567') + one_megabyte = one_megabyte.getvalue() + + key = _GCS_URL + '/performance.txt.gz' + actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb') + assert actual == one_megabyte + + +def test_gcs_performance_small_reads(benchmark): + initialize_bucket() + + ONE_MIB = 1024**2 + one_megabyte_of_msgs = io.BytesIO() + msg = b'\x0f' + b'0123456789abcde' # a length-prefixed "message" + for _ in range(0, ONE_MIB, len(msg)): + one_megabyte_of_msgs.write(msg) + one_megabyte_of_msgs = one_megabyte_of_msgs.getvalue() + + key = _GCS_URL + '/many_reads_performance.bin' + + with smart_open.open(key, 'wb') as fout: + fout.write(one_megabyte_of_msgs) + + actual = benchmark(read_length_prefixed_messages, key, 'rb', buffering=ONE_MIB) + assert actual == one_megabyte_of_msgs diff --git a/setup.py b/setup.py 
index 9acc9281..73b83341 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ def read(fname): 'boto >= 2.32', 'requests', 'boto3', + 'google-cloud-storage', ] if sys.version_info[0] == 2: install_requires.append('bz2file') @@ -66,7 +67,7 @@ def read(fname): setup( name='smart_open', version=__version__, - description='Utils for streaming large files (S3, HDFS, gzip, bz2...)', + description='Utils for streaming large files (S3, HDFS, GCS, gzip, bz2...)', long_description=read('README.rst'), packages=find_packages(), @@ -82,7 +83,7 @@ def read(fname): url='https://github.com/piskvorky/smart_open', download_url='http://pypi.python.org/pypi/smart_open', - keywords='file streaming, s3, hdfs', + keywords='file streaming, s3, hdfs, gcs', license='MIT', platforms='any', diff --git a/smart_open/gcs.py b/smart_open/gcs.py new file mode 100644 index 00000000..dd33ae39 --- /dev/null +++ b/smart_open/gcs.py @@ -0,0 +1,542 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Radim Rehurek +# +# This code is distributed under the terms and conditions +# from the MIT License (MIT). +# +"""Implements file-like objects for reading and writing to/from GCS.""" + +import io +import logging +import sys + +import google.cloud.exceptions +import google.cloud.storage +import google.auth.transport.requests as google_requests +import six + +import smart_open.bytebuffer +import smart_open.s3 + +logger = logging.getLogger(__name__) + +_READ_BINARY = 'rb' +_WRITE_BINARY = 'wb' + +_MODES = (_READ_BINARY, _WRITE_BINARY) +"""Allowed I/O modes for working with GCS.""" + +_BINARY_TYPES = (six.binary_type, bytearray) +"""Allowed binary buffer types for writing to the underlying GCS stream""" + +if sys.version_info >= (2, 7): + _BINARY_TYPES = (six.binary_type, bytearray, memoryview) + +_BINARY_NEWLINE = b'\n' + +_UNKNOWN_FILE_SIZE = '*' + +SUPPORTED_SCHEME = "gs" +"""Supported scheme for GCS""" + +_MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 256 * 1024 +"""Google requires you to upload in multiples of 256 KB, except for the last part.""" + +_DEFAULT_MIN_PART_SIZE = 50 * 1024**2 +"""Default minimum part size for GCS multipart uploads""" + +DEFAULT_BUFFER_SIZE = 256 * 1024 +"""Default buffer size for working with GCS""" + +START = 0 +"""Seek to the absolute start of a GCS file""" + +CURRENT = 1 +"""Seek relative to the current position of a GCS file""" + +END = 2 +"""Seek relative to the end of a GCS file""" + +_WHENCE_CHOICES = (START, CURRENT, END) + +_SUCCESSFUL_STATUS_CODES = (200, 201) + + +def _make_range_string(start, stop=None, end=_UNKNOWN_FILE_SIZE): + # + # https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks + # + if stop is None: + return 'bytes %d-/%s' % (start, end) + return 'bytes %d-%d/%s' % (start, stop, end) + + +class UploadFailedError(Exception): + def __init__(self, message, status_code, text): + """Raise when a multi-part upload to GCS returns a failed response status code. + + Parameters + ---------- + message: str + The error message to display. + status_code: int + The status code returned from the upload response. + text: str + The text returned from the upload response. + + """ + super(UploadFailedError, self).__init__(message) + self.status_code = status_code + self.text = text + + +def open( + bucket_id, + blob_id, + mode, + buffer_size=DEFAULT_BUFFER_SIZE, + min_part_size=_MIN_MIN_PART_SIZE, + client=None, # type: google.cloud.storage.Client + ): + """Open a GCS blob for reading or writing.
+ + Parameters + ---------- + bucket_id: str + The name of the bucket this object resides in. + blob_id: str + The name of the blob within the bucket. + mode: str + The mode for opening the object. Must be either "rb" or "wb". + buffer_size: int, optional + The buffer size to use when performing I/O. For reading only. + min_part_size: int, optional + The minimum part size for multipart uploads. For writing only. + client: google.cloud.storage.Client, optional + The GCS client to use when working with google-cloud-storage. + + """ + if mode == _READ_BINARY: + return SeekableBufferedInputBase( + bucket_id, + blob_id, + buffer_size=buffer_size, + line_terminator=_BINARY_NEWLINE, + client=client, + ) + elif mode == _WRITE_BINARY: + return BufferedOutputBase( + bucket_id, + blob_id, + min_part_size=min_part_size, + client=client, + ) + else: + raise NotImplementedError('GCS support for mode %r not implemented' % mode) + + +class _SeekableRawReader(object): + """Read a GCS object.""" + + def __init__(self, gcs_blob, size): + # type: (google.cloud.storage.Blob, int) -> None + self._blob = gcs_blob + self._size = size + self._position = 0 + + def seek(self, position): + """Seek to the specified position (byte offset) in the GCS key. + + :param int position: The byte offset from the beginning of the key. + + Returns the position after seeking. + """ + self._position = position + return self._position + + def read(self, size=-1): + if self._position >= self._size: + return b'' + binary = self._download_blob_chunk(size) + self._position += len(binary) + return binary + + def _download_blob_chunk(self, size): + start = position = self._position + if position == self._size: + # + # When reading, we can't seek to the first byte of an empty file. + # Similarly, we can't seek past the last byte. Do nothing here. + # + binary = b'' + elif size == -1: + binary = self._blob.download_as_string(start=start) + else: + end = position + size + binary = self._blob.download_as_string(start=start, end=end) + return binary + + +class SeekableBufferedInputBase(io.BufferedIOBase): + """Reads bytes from GCS. + + Implements the io.BufferedIOBase interface of the standard library. + + :raises google.cloud.exceptions.NotFound: Raised when the blob to read from does not exist. + + """ + def __init__( + self, + bucket, + key, + buffer_size=DEFAULT_BUFFER_SIZE, + line_terminator=_BINARY_NEWLINE, + client=None, # type: google.cloud.storage.Client + ): + if client is None: + client = google.cloud.storage.Client() + bucket = client.get_bucket(bucket) # type: google.cloud.storage.Bucket + + self._blob = bucket.get_blob(key) + if self._blob is None: + raise google.cloud.exceptions.NotFound('blob {} not found in {}'.format(key, bucket)) + self._size = self._blob.size if self._blob.size is not None else 0 + + self._raw_reader = _SeekableRawReader(self._blob, self._size) + self._current_pos = 0 + self._current_part_size = buffer_size + self._current_part = smart_open.bytebuffer.ByteBuffer(buffer_size) + self._eof = False + self._line_terminator = line_terminator + + # + # This member is part of the io.BufferedIOBase interface. + # + self.raw = None + + # + # Override some methods from io.IOBase.
+ # + def close(self): + """Flush and close this stream.""" + logger.debug("close: called") + self._blob = None + self._current_part = None + self._raw_reader = None + + def readable(self): + """Return True if the stream can be read from.""" + return True + + def seekable(self): + """If False, seek(), tell() and truncate() will raise IOError. + + We offer only seek support, and no truncate support.""" + return True + + # + # io.BufferedIOBase methods. + # + def detach(self): + """Unsupported.""" + raise io.UnsupportedOperation + + def seek(self, offset, whence=START): + """Seek to the specified position. + + :param int offset: The offset in bytes. + :param int whence: Where the offset is from. + + Returns the position after seeking.""" + logger.debug('seeking to offset: %r whence: %r', offset, whence) + if whence not in _WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % _WHENCE_CHOICES) + + if whence == START: + new_position = offset + elif whence == CURRENT: + new_position = self._current_pos + offset + else: + new_position = self._size + offset + new_position = smart_open.s3.clamp(new_position, 0, self._size) + self._current_pos = new_position + self._raw_reader.seek(new_position) + logger.debug('current_pos: %r', self._current_pos) + + self._current_part.empty() + self._eof = self._current_pos == self._size + return self._current_pos + + def tell(self): + """Return the current position within the file.""" + return self._current_pos + + def truncate(self, size=None): + """Unsupported.""" + raise io.UnsupportedOperation + + def read(self, size=-1): + """Read up to size bytes from the object and return them.""" + if size == 0: + return b'' + elif size < 0: + self._current_pos = self._size + return self._read_from_buffer() + self._raw_reader.read() + + # + # Return unused data first + # + if len(self._current_part) >= size: + return self._read_from_buffer(size) + + # + # If the stream is finished, return what we have. + # + if self._eof: + return self._read_from_buffer() + + # + # Fill our buffer to the required size. + # + self._fill_buffer(size) + return self._read_from_buffer(size) + + def read1(self, size=-1): + """This is the same as read().""" + return self.read(size=size) + + def readinto(self, b): + """Read up to len(b) bytes into b, and return the number of bytes + read.""" + data = self.read(len(b)) + if not data: + return 0 + b[:len(data)] = data + return len(data) + + def readline(self, limit=-1): + """Read up to and including the next newline. Returns the bytes read.""" + if limit != -1: + raise NotImplementedError('limits other than -1 not implemented yet') + the_line = io.BytesIO() + while not (self._eof and len(self._current_part) == 0): + # + # In the worst case, we're reading the unread part of self._current_part + # twice here, once in the if condition and once when calling index. + # + # This is sub-optimal, but better than the alternative: wrapping + # .index in a try..except, because that is slower. + # + remaining_buffer = self._current_part.peek() + if self._line_terminator in remaining_buffer: + next_newline = remaining_buffer.index(self._line_terminator) + the_line.write(self._read_from_buffer(next_newline + 1)) + break + else: + the_line.write(self._read_from_buffer()) + self._fill_buffer() + return the_line.getvalue() + + # + # Internal methods. 
# + def _read_from_buffer(self, size=-1): + """Remove at most size bytes from our buffer and return them.""" + # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part)) + size = size if size >= 0 else len(self._current_part) + part = self._current_part.read(size) + self._current_pos += len(part) + # logger.debug('part: %r', part) + return part + + def _fill_buffer(self, size=-1): + size = size if size >= 0 else self._current_part._chunk_size + while len(self._current_part) < size and not self._eof: + bytes_read = self._current_part.fill(self._raw_reader) + if bytes_read == 0: + logger.debug('reached EOF while filling buffer') + self._eof = True + + def __str__(self): + return "(%s, %r, %r)" % (self.__class__.__name__, self._blob.bucket.name, self._blob.name) + + def __repr__(self): + return ( + "%s(" + "bucket=%r, " + "blob=%r, " + "buffer_size=%r)" + ) % ( + self.__class__.__name__, + self._blob.bucket.name, + self._blob.name, + self._current_part_size, + ) + + +class BufferedOutputBase(io.BufferedIOBase): + """Writes bytes to GCS. + + Implements the io.BufferedIOBase interface of the standard library.""" + + def __init__( + self, + bucket, + blob, + min_part_size=_DEFAULT_MIN_PART_SIZE, + client=None, # type: google.cloud.storage.Client + ): + if client is None: + client = google.cloud.storage.Client() + self._client = client + self._credentials = self._client._credentials # noqa + self._bucket = self._client.bucket(bucket) # type: google.cloud.storage.Bucket + self._blob = self._bucket.blob(blob) # type: google.cloud.storage.Blob + assert min_part_size % _REQUIRED_CHUNK_MULTIPLE == 0, 'min part size must be a multiple of 256KB' + assert min_part_size >= _MIN_MIN_PART_SIZE, 'min part size must be at least 256 KB' + self._min_part_size = min_part_size + + self._total_size = 0 + self._total_parts = 0 + self._current_part = io.BytesIO() + + self._session = google_requests.AuthorizedSession(self._credentials) + + # + # https://cloud.google.com/storage/docs/json_api/v1/how-tos/resumable-upload#start-resumable + # + self._resumable_upload_url = self._blob.create_resumable_upload_session() + + # + # This member is part of the io.BufferedIOBase interface. + # + self.raw = None + + def flush(self): + pass + + # + # Override some methods from io.IOBase. + # + def close(self): + logger.debug("closing") + if self._total_size == 0: # empty files + self._upload_empty_part() + if self._current_part.tell(): + self._upload_next_part() + logger.debug("successfully closed") + + def writable(self): + """Return True if the stream supports writing.""" + return True + + def tell(self): + """Return the current stream position.""" + return self._total_size + + # + # io.BufferedIOBase methods. + # + def detach(self): + raise io.UnsupportedOperation("detach() not supported") + + def write(self, b): + """Write the given bytes (binary string) to the GCS file.
+ + There's buffering happening under the covers, so this may not actually + do any HTTP transfer right away.""" + + if not isinstance(b, _BINARY_TYPES): + raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b))) + + self._current_part.write(b) + self._total_size += len(b) + + if self._current_part.tell() >= self._min_part_size: + self._upload_next_part() + + return len(b) + + def terminate(self): + """Cancel the underlying resumable upload.""" + # + # https://cloud.google.com/storage/docs/xml-api/resumable-upload#example_cancelling_an_upload + # + self._session.delete(self._resumable_upload_url) + + # + # Internal methods. + # + def _upload_next_part(self): + part_num = self._total_parts + 1 + logger.info( + "uploading part #%i, %i bytes (total %.3fGB)", + part_num, + self._current_part.tell(), + self._total_size / 1024.0 ** 3 + ) + content_length = end = self._current_part.tell() + start = self._total_size - content_length + stop = self._total_size - 1 + + self._current_part.seek(0) + + headers = { + 'Content-Length': str(content_length), + 'Content-Range': _make_range_string(start, stop, end) + } + response = self._session.put(self._resumable_upload_url, data=self._current_part, headers=headers) + + if response.status_code not in _SUCCESSFUL_STATUS_CODES: + msg = ( + "upload failed (" + "status code: %i, " + "response text=%s, " + "part #%i, " + "%i bytes (total %.3fGB)" + ) % ( + response.status_code, + response.text, + part_num, + self._current_part.tell(), + self._total_size / 1024.0 ** 3, + ) + raise UploadFailedError(msg, response.status_code, response.text) + logger.debug("upload of part #%i finished", part_num) + + self._total_parts += 1 + self._current_part = io.BytesIO() + + def _upload_empty_part(self): + logger.debug("creating empty file") + headers = {'Content-Length': '0'} + response = self._session.put(self._resumable_upload_url, headers=headers) + assert response.status_code in _SUCCESSFUL_STATUS_CODES + + self._total_parts += 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + self.terminate() + else: + self.close() + + def __str__(self): + return "(%s, %r, %r)" % (self.__class__.__name__, self._bucket.name, self._blob.name) + + def __repr__(self): + return ( + "%s(" + "bucket=%r, " + "blob=%r, " + "min_part_size=%r)" + ) % ( + self.__class__.__name__, + self._bucket.name, + self._blob.name, + self._min_part_size, + ) diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index b39aecc3..037a07f1 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -41,6 +41,7 @@ import smart_open.webhdfs as smart_open_webhdfs import smart_open.http as smart_open_http import smart_open.ssh as smart_open_ssh +import smart_open.gcs as smart_open_gcs from smart_open import doctools @@ -122,6 +123,7 @@ def _handle_gzip(file_obj, mode): 'uri_path', 'bucket_id', 'key_id', + 'blob_id', 'port', 'host', 'ordinary_calling_format', @@ -134,7 +136,7 @@ def _handle_gzip(file_obj, mode): """Represents all the options that we parse from user input. Some of the above options only make sense for certain protocols, e.g. -bucket_id is only for S3. +bucket_id is only for S3 and GCS. """ # # Set the default values for all Uri fields to be None.
This allows us to only @@ -380,6 +382,10 @@ def open( doctools.extract_kwargs(smart_open_ssh.open.__doc__), lpad=u' ', ), + 'gcs': doctools.to_docstring( + doctools.extract_kwargs(smart_open_gcs.open.__doc__), + lpad=u' ', + ), 'examples': doctools.extract_examples_from_readme_rst(), } @@ -575,6 +581,9 @@ def _open_binary_stream(uri, mode, transport_params): filename = P.basename(urlparse.urlparse(uri).path) kw = _check_kwargs(smart_open_http.open, transport_params) return smart_open_http.open(uri, mode, **kw), filename + elif parsed_uri.scheme == smart_open_gcs.SUPPORTED_SCHEME: + kw = _check_kwargs(smart_open_gcs.open, transport_params) + return smart_open_gcs.open(parsed_uri.bucket_id, parsed_uri.blob_id, mode, **kw), filename else: raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme) elif hasattr(uri, 'read'): @@ -684,6 +693,7 @@ def _parse_uri(uri_as_string): Supported URI schemes are: * file + * gs * hdfs * http * https @@ -711,6 +721,7 @@ def _parse_uri(uri_as_string): * file:///home/user/file.bz2 * [ssh|scp|sftp]://username@host//path/file * [ssh|scp|sftp]://username@host/path/file + * gs://my_bucket/my_blob """ if os.name == 'nt': @@ -735,6 +746,8 @@ def _parse_uri(uri_as_string): return Uri(scheme=parsed_uri.scheme, uri_path=uri_as_string) elif parsed_uri.scheme in smart_open_ssh.SCHEMES: return _parse_uri_ssh(parsed_uri) + elif parsed_uri.scheme == smart_open_gcs.SUPPORTED_SCHEME: + return _parse_uri_gcs(parsed_uri) else: raise NotImplementedError( "unknown URI scheme %r in %r" % (parsed_uri.scheme, uri_as_string) @@ -831,6 +844,13 @@ def _unquote(text): return text and urlparse.unquote(text) +def _parse_uri_gcs(parsed_uri): + assert parsed_uri.scheme == smart_open_gcs.SUPPORTED_SCHEME + bucket_id, blob_id = parsed_uri.netloc, parsed_uri.path[1:] + + return Uri(scheme=parsed_uri.scheme, bucket_id=bucket_id, blob_id=blob_id) + + def _need_to_buffer(file_obj, mode, ext): """Returns True if we need to buffer the whole file in memory in order to proceed.""" try: diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py new file mode 100644 index 00000000..df8c4c06 --- /dev/null +++ b/smart_open/tests/test_gcs.py @@ -0,0 +1,829 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Radim Rehurek +# +# This code is distributed under the terms and conditions +# from the MIT License (MIT). 
+# +import gzip +import inspect +import io +import logging +import os +import time +import uuid +import unittest +try: + from unittest import mock +except ImportError: + import mock +import warnings +from collections import OrderedDict + +import google.cloud +import google.api_core.exceptions +import six + +import smart_open + +BUCKET_NAME = 'test-smartopen-{}'.format(uuid.uuid4().hex) +BLOB_NAME = 'test-blob' +WRITE_BLOB_NAME = 'test-write-blob' +DISABLE_MOCKS = os.environ.get('SO_DISABLE_MOCKS') == "1" + +RESUMABLE_SESSION_URI_TEMPLATE = ( + 'https://www.googleapis.com/upload/storage/v1/b/' + '%(bucket)s' + '/o?uploadType=resumable&upload_id=' + '%(upload_id)s' +) + +logger = logging.getLogger(__name__) + + +def ignore_resource_warnings(): + if six.PY2: + return + warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") # noqa + + +class FakeBucket(object): + def __init__(self, client, name=None): + self.client = client # type: FakeClient + self.name = name + self.blobs = OrderedDict() + self._exists = True + + # + # This is simpler than creating a backend and metaclass to store the state of every bucket created + # + self.client.register_bucket(self) + + def blob(self, blob_id): + return self.blobs.get(blob_id, FakeBlob(blob_id, self)) + + def delete(self): + self.client.delete_bucket(self) + self._exists = False + for blob in list(self.blobs.values()): + blob.delete() + + def exists(self): + return self._exists + + def get_blob(self, blob_id): + try: + return self.blobs[blob_id] + except KeyError: + raise google.cloud.exceptions.NotFound('Blob {} not found'.format(blob_id)) + + def list_blobs(self): + return list(self.blobs.values()) + + def delete_blob(self, blob): + del self.blobs[blob.name] + + def register_blob(self, blob): + if blob.name not in self.blobs.keys(): + self.blobs[blob.name] = blob + + def register_upload(self, upload): + self.client.register_upload(upload) + + +class FakeBucketTest(unittest.TestCase): + def setUp(self): + self.client = FakeClient() + self.bucket = FakeBucket(self.client, 'test-bucket') + + def test_blob_registers_with_bucket(self): + blob_id = 'blob.txt' + expected = FakeBlob(blob_id, self.bucket) + actual = self.bucket.blob(blob_id) + self.assertEqual(actual, expected) + + def test_blob_alternate_constructor(self): + blob_id = 'blob.txt' + expected = self.bucket.blob(blob_id) + actual = self.bucket.list_blobs()[0] + self.assertEqual(actual, expected) + + def test_delete(self): + blob_id = 'blob.txt' + blob = FakeBlob(blob_id, self.bucket) + self.bucket.delete() + self.assertFalse(self.bucket.exists()) + self.assertFalse(blob.exists()) + + def test_get_multiple_blobs(self): + blob_one_id = 'blob_one.avro' + blob_two_id = 'blob_two.parquet' + blob_one = self.bucket.blob(blob_one_id) + blob_two = self.bucket.blob(blob_two_id) + actual_first_blob = self.bucket.get_blob(blob_one_id) + actual_second_blob = self.bucket.get_blob(blob_two_id) + self.assertEqual(actual_first_blob, blob_one) + self.assertEqual(actual_second_blob, blob_two) + + def test_get_nonexistent_blob(self): + with self.assertRaises(google.cloud.exceptions.NotFound): + self.bucket.get_blob('test-blob') + + def test_list_blobs(self): + blob_one = self.bucket.blob('blob_one.avro') + blob_two = self.bucket.blob('blob_two.parquet') + actual = self.bucket.list_blobs() + expected = [blob_one, blob_two] + self.assertEqual(actual, expected) + + +class FakeBlob(object): + def __init__(self, name, bucket, create=True): + self.name = name + self._bucket = bucket # type:
FakeBucket + self._exists = False + self.__contents = io.BytesIO() + + if create: + self._create_if_not_exists() + + def create_resumable_upload_session(self): + resumable_upload_url = RESUMABLE_SESSION_URI_TEMPLATE % dict( + bucket=self._bucket.name, + upload_id=str(uuid.uuid4()), + ) + upload = FakeBlobUpload(resumable_upload_url, self) + self._bucket.register_upload(upload) + return resumable_upload_url + + def delete(self): + self._bucket.delete_blob(self) + self._exists = False + + def download_as_string(self, start=None, end=None): + if start is None: + start = 0 + if end is None: + end = self.__contents.tell() + self.__contents.seek(start) + return self.__contents.read(end - start) + + def exists(self, client=None): + return self._exists + + def upload_from_string(self, str_): + self.__contents.write(str_) + + def write(self, data): + self.upload_from_string(data) + + @property + def bucket(self): + return self._bucket + + @property + def size(self): + if self.__contents.tell() == 0: + return None + return self.__contents.tell() + + def _create_if_not_exists(self): + self._bucket.register_blob(self) + self._exists = True + + +class FakeBlobTest(unittest.TestCase): + def setUp(self): + self.client = FakeClient() + self.bucket = FakeBucket(self.client, 'test-bucket') + + def test_create_resumable_upload_session(self): + blob = FakeBlob('fake-blob', self.bucket) + resumable_upload_url = blob.create_resumable_upload_session() + self.assertTrue(resumable_upload_url in self.client.uploads) + + def test_delete(self): + blob = FakeBlob('fake-blob', self.bucket) + blob.delete() + self.assertFalse(blob.exists()) + self.assertEqual(self.bucket.list_blobs(), []) + + def test_upload_download(self): + blob = FakeBlob('fake-blob', self.bucket) + contents = b'test' + blob.upload_from_string(contents) + self.assertEqual(blob.download_as_string(), b'test') + self.assertEqual(blob.download_as_string(start=2), b'st') + self.assertEqual(blob.download_as_string(end=2), b'te') + self.assertEqual(blob.download_as_string(start=2, end=3), b's') + + def test_size(self): + blob = FakeBlob('fake-blob', self.bucket) + self.assertEqual(blob.size, None) + blob.upload_from_string(b'test') + self.assertEqual(blob.size, 4) + + +class FakeCredentials(object): + def __init__(self, client): + self.client = client # type: FakeClient + + def before_request(self, *args, **kwargs): + pass + + +class FakeClient(object): + def __init__(self, credentials=None): + if credentials is None: + credentials = FakeCredentials(self) + self._credentials = credentials # type: FakeCredentials + self.uploads = OrderedDict() + self.__buckets = OrderedDict() + + def bucket(self, bucket_id): + try: + return self.__buckets[bucket_id] + except KeyError: + raise google.cloud.exceptions.NotFound('Bucket %s not found' % bucket_id) + + def create_bucket(self, bucket_id): + bucket = FakeBucket(self, bucket_id) + return bucket + + def get_bucket(self, bucket_id): + return self.bucket(bucket_id) + + def register_bucket(self, bucket): + if bucket.name in self.__buckets: + raise google.cloud.exceptions.Conflict('Bucket %s already exists' % bucket.name) + self.__buckets[bucket.name] = bucket + + def delete_bucket(self, bucket): + del self.__buckets[bucket.name] + + def register_upload(self, upload): + self.uploads[upload.url] = upload + + +class FakeClientTest(unittest.TestCase): + def setUp(self): + self.client = FakeClient() + + def test_nonexistent_bucket(self): + with self.assertRaises(google.cloud.exceptions.NotFound):
self.client.bucket('test-bucket') + + def test_bucket(self): + bucket_id = 'test-bucket' + bucket = FakeBucket(self.client, bucket_id) + actual = self.client.bucket(bucket_id) + self.assertEqual(actual, bucket) + + def test_duplicate_bucket(self): + bucket_id = 'test-bucket' + FakeBucket(self.client, bucket_id) + with self.assertRaises(google.cloud.exceptions.Conflict): + FakeBucket(self.client, bucket_id) + + def test_create_bucket(self): + bucket_id = 'test-bucket' + bucket = self.client.create_bucket(bucket_id) + actual = self.client.get_bucket(bucket_id) + self.assertEqual(actual, bucket) + + +class FakeBlobUpload(object): + def __init__(self, url, blob): + self.url = url + self.blob = blob # type: FakeBlob + self.__contents = io.BytesIO() + + def write(self, data): + self.__contents.write(data) + + def finish(self): + self.__contents.seek(0) + data = self.__contents.read() + self.blob.upload_from_string(data) + + def terminate(self): + self.blob.delete() + self.__contents = None + + +class FakeResponse(object): + def __init__(self, status_code=200): + self.status_code = status_code + + +class FakeAuthorizedSession(object): + def __init__(self, credentials): + self._credentials = credentials # type: FakeCredentials + + def delete(self, upload_url): + upload = self._credentials.client.uploads.pop(upload_url) + upload.terminate() + + def put(self, url, data=None, headers=None): + if data is not None: + upload = self._credentials.client.uploads[url] + upload.write(data.read()) + if not headers['Content-Range'].endswith(smart_open.gcs._UNKNOWN_FILE_SIZE): + upload.finish() + return FakeResponse() + + @staticmethod + def _blob_with_url(url, client): + # type: (str, FakeClient) -> FakeBlobUpload + return client.uploads.get(url) + + +class FakeAuthorizedSessionTest(unittest.TestCase): + def setUp(self): + self.client = FakeClient() + self.credentials = FakeCredentials(self.client) + self.session = FakeAuthorizedSession(self.credentials) + self.bucket = FakeBucket(self.client, 'test-bucket') + self.blob = FakeBlob('test-blob', self.bucket) + self.upload_url = self.blob.create_resumable_upload_session() + + def test_delete(self): + self.session.delete(self.upload_url) + self.assertFalse(self.blob.exists()) + self.assertDictEqual(self.client.uploads, {}) + + def test_unfinished_put_does_not_write_to_blob(self): + data = io.BytesIO(b'test') + headers = { + 'Content-Range': 'bytes 0-3/*', + 'Content-Length': str(4), + } + response = self.session.put(self.upload_url, data, headers=headers) + self.assertEqual(response.status_code, 200) + self.session._blob_with_url(self.upload_url, self.client) + blob_contents = self.blob.download_as_string() + self.assertEqual(blob_contents, b'') + + def test_finished_put_writes_to_blob(self): + data = io.BytesIO(b'test') + headers = { + 'Content-Range': 'bytes 0-3/4', + 'Content-Length': str(4), + } + response = self.session.put(self.upload_url, data, headers=headers) + self.assertEqual(response.status_code, 200) + self.session._blob_with_url(self.upload_url, self.client) + blob_contents = self.blob.download_as_string() + data.seek(0) + self.assertEqual(blob_contents, data.read()) + + +if DISABLE_MOCKS: + storage_client = google.cloud.storage.Client() +else: + storage_client = FakeClient() + + +def get_bucket(): + return storage_client.bucket(BUCKET_NAME) + + +def get_blob(): + bucket = get_bucket() + return bucket.blob(BLOB_NAME) + + +def cleanup_bucket(): + bucket = get_bucket() + + blobs = bucket.list_blobs() + for blob in blobs: + blob.delete() + + +def 
put_to_bucket(contents, num_attempts=12, sleep_time=5): + logger.debug('%r', locals()) + + # + # In real life, it can take a few seconds for the bucket to become ready. + # If we try to write to the key before the bucket is ready, we + # will get a StorageError: NotFound. + # + for attempt in range(num_attempts): + try: + blob = get_blob() + blob.upload_from_string(contents) + return + except google.cloud.exceptions.NotFound as err: + logger.error('caught %r, retrying', err) + time.sleep(sleep_time) + + assert False, 'failed to write to bucket %s after %d attempts' % (BUCKET_NAME, num_attempts) + + +def mock_gcs(class_or_func): + """Mock all methods of a class or a function.""" + if inspect.isclass(class_or_func): + for attr in class_or_func.__dict__: + if callable(getattr(class_or_func, attr)): + setattr(class_or_func, attr, mock_gcs_func(getattr(class_or_func, attr))) + return class_or_func + else: + return mock_gcs_func(class_or_func) + + +def mock_gcs_func(func): + """Mock the function and provide additional required arguments.""" + def inner(*args, **kwargs): + with mock.patch('google.cloud.storage.Client', return_value=storage_client), \ + mock.patch( + 'smart_open.gcs.google_requests.AuthorizedSession', + return_value=FakeAuthorizedSession(storage_client._credentials), + ): + assert callable(func), 'you didn\'t provide a function!' + try: # is it a method that needs a self arg? + self_arg = inspect.signature(func).self + func(self_arg, *args, **kwargs) + except AttributeError: + func(*args, **kwargs) + return inner + + +def maybe_mock_gcs(func): + if DISABLE_MOCKS: + return func + else: + return mock_gcs(func) + + +@maybe_mock_gcs +def setUpModule(): # noqa + """Called once by unittest when initializing this module. Set up the + test GCS bucket. + """ + storage_client.create_bucket(BUCKET_NAME) + + +@maybe_mock_gcs +def tearDownModule(): # noqa + """Called once by unittest when tearing down this module. Empties and + removes the test GCS bucket.
+ """ + try: + bucket = get_bucket() + bucket.delete() + except google.cloud.exceptions.NotFound: + pass + + +@maybe_mock_gcs +class SeekableBufferedInputBaseTest(unittest.TestCase): + def setUp(self): + # lower the multipart upload size, to speed up these tests + self.old_min_buffer_size = smart_open.gcs.DEFAULT_BUFFER_SIZE + smart_open.gcs.DEFAULT_BUFFER_SIZE = 5 * 1024**2 + + ignore_resource_warnings() + + def tearDown(self): + cleanup_bucket() + + def test_iter(self): + """Are GCS files iterated over correctly?""" + expected = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=expected) + + # connect to fake GCS and read from the fake key we filled above + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + output = [line.rstrip(b'\n') for line in fin] + self.assertEqual(output, expected.split(b'\n')) + + def test_iter_context_manager(self): + # same thing but using a context manager + expected = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=expected) + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) as fin: + output = [line.rstrip(b'\n') for line in fin] + self.assertEqual(output, expected.split(b'\n')) + + def test_read(self): + """Are GCS files read correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=content) + logger.debug('content: %r len: %r', content, len(content)) + + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + self.assertEqual(content[:6], fin.read(6)) + self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes + self.assertEqual(content[14:], fin.read()) # read the rest + + def test_seek_beginning(self): + """Does seeking to the beginning of GCS files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=content) + + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + self.assertEqual(content[:6], fin.read(6)) + self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes + + fin.seek(0) + self.assertEqual(content, fin.read()) # no size given => read whole file + + fin.seek(0) + self.assertEqual(content, fin.read(-1)) # same thing + + def test_seek_start(self): + """Does seeking from the start of GCS files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=content) + + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + seek = fin.seek(6) + self.assertEqual(seek, 6) + self.assertEqual(fin.tell(), 6) + self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) + + def test_seek_current(self): + """Does seeking from the middle of GCS files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=content) + + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + self.assertEqual(fin.read(5), b'hello') + seek = fin.seek(1, whence=smart_open.gcs.CURRENT) + self.assertEqual(seek, 6) + self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) + + def test_seek_end(self): + """Does seeking from the end of GCS files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + put_to_bucket(contents=content) + + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + seek = fin.seek(-4, whence=smart_open.gcs.END) + self.assertEqual(seek, len(content) - 4) + self.assertEqual(fin.read(), b'you?') + + def test_detect_eof(self): + content = u"hello wořld\nhow are you?".encode('utf8') + 
put_to_bucket(contents=content) + + fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) + fin.read() + eof = fin.tell() + self.assertEqual(eof, len(content)) + fin.seek(0, whence=smart_open.gcs.END) + self.assertEqual(eof, fin.tell()) + + def test_read_gzip(self): + expected = u'раcцветали яблони и груши, поплыли туманы над рекой...'.encode('utf-8') + buf = io.BytesIO() + buf.close = lambda: None # keep buffer open so that we can .getvalue() + with gzip.GzipFile(fileobj=buf, mode='w') as zipfile: + zipfile.write(expected) + put_to_bucket(contents=buf.getvalue()) + + # + # Make sure we're reading things correctly. + # + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) as fin: + self.assertEqual(fin.read(), buf.getvalue()) + + # + # Make sure the buffer we wrote is legitimate gzip. + # + sanity_buf = io.BytesIO(buf.getvalue()) + with gzip.GzipFile(fileobj=sanity_buf) as zipfile: + self.assertEqual(zipfile.read(), expected) + + logger.debug('starting actual test') + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) as fin: + with gzip.GzipFile(fileobj=fin) as zipfile: + actual = zipfile.read() + + self.assertEqual(expected, actual) + + def test_readline(self): + content = b'englishman\nin\nnew\nyork\n' + put_to_bucket(contents=content) + + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) as fin: + fin.readline() + self.assertEqual(fin.tell(), content.index(b'\n')+1) + + fin.seek(0) + actual = list(fin) + self.assertEqual(fin.tell(), len(content)) + + expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] + self.assertEqual(expected, actual) + + def test_readline_tiny_buffer(self): + content = b'englishman\nin\nnew\nyork\n' + put_to_bucket(contents=content) + + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME, buffer_size=8) as fin: + actual = list(fin) + + expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] + self.assertEqual(expected, actual) + + def test_read0_does_not_return_data(self): + content = b'englishman\nin\nnew\nyork\n' + put_to_bucket(contents=content) + + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) as fin: + data = fin.read(0) + + self.assertEqual(data, b'') + + def test_read_past_end(self): + content = b'englishman\nin\nnew\nyork\n' + put_to_bucket(contents=content) + + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) as fin: + data = fin.read(100) + + self.assertEqual(data, content) + + +@maybe_mock_gcs +class BufferedOutputBaseTest(unittest.TestCase): + """ + Test writing into GCS files. 
+ + """ + def setUp(self): + ignore_resource_warnings() + + def tearDown(self): + cleanup_bucket() + + def test_write_01(self): + """Does writing into GCS work correctly?""" + test_string = u"žluťoučký koníček".encode('utf8') + + with smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fout: + fout.write(test_string) + + output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb")) + + self.assertEqual(output, [test_string]) + + def test_write_01a(self): + """Does gcs write fail on incorrect input?""" + try: + with smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fin: + fin.write(None) + except TypeError: + pass + else: + self.fail() + + def test_write_02(self): + """Does gcs write unicode-utf8 conversion work?""" + smart_open_write = smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) + smart_open_write.tell() + logger.info("smart_open_write: %r", smart_open_write) + with smart_open_write as fout: + fout.write(u"testžížáč".encode("utf-8")) + self.assertEqual(fout.tell(), 14) + + def test_write_03(self): + """Does gcs multipart chunking work correctly?""" + # write + smart_open_write = smart_open.gcs.BufferedOutputBase( + BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=256 * 1024 + ) + with smart_open_write as fout: + fout.write(b"t" * 262141) + self.assertEqual(fout._current_part.tell(), 262141) + + fout.write(b"t\n") + self.assertEqual(fout._current_part.tell(), 262143) + self.assertEqual(fout._total_parts, 0) + + fout.write(b"t") + self.assertEqual(fout._current_part.tell(), 0) + self.assertEqual(fout._total_parts, 1) + + # read back the same key and check its content + output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME))) + self.assertEqual(output, ["t" * 262142 + '\n', "t"]) + + def test_write_04(self): + """Does writing no data cause key with an empty value to be created?""" + smart_open_write = smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) + with smart_open_write as fout: # noqa + pass + + # read back the same key and check its content + output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME))) + + self.assertEqual(output, []) + + def test_gzip(self): + expected = u'а не спеть ли мне песню... 
о любви'.encode('utf-8') + with smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fout: + with gzip.GzipFile(fileobj=fout, mode='w') as zipfile: + zipfile.write(expected) + + with smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fin: + with gzip.GzipFile(fileobj=fin) as zipfile: + actual = zipfile.read() + + self.assertEqual(expected, actual) + + def test_buffered_writer_wrapper_works(self): + """ + Ensure that we can wrap a smart_open gcs stream in a BufferedWriter, which + passes a memoryview object to the underlying stream in python >= 2.7 + """ + expected = u'не думай о секундах свысока' + + with smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fout: + with io.BufferedWriter(fout) as sub_out: + sub_out.write(expected.encode('utf-8')) + + with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), 'rb') as fin: + with io.TextIOWrapper(fin, encoding='utf-8') as text: + actual = text.read() + + self.assertEqual(expected, actual) + + def test_binary_iterator(self): + expected = u"выйду ночью в поле с конём".encode('utf-8').split(b' ') + put_to_bucket(contents=b"\n".join(expected)) + with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, 'rb') as fin: + actual = [line.rstrip() for line in fin] + self.assertEqual(expected, actual) + + def test_nonexisting_bucket(self): + expected = u"выйду ночью в поле с конём".encode('utf-8') + with self.assertRaises(google.api_core.exceptions.NotFound): + with smart_open.gcs.open('thisbucketdoesntexist', 'mykey', 'wb') as fout: + fout.write(expected) + + def test_read_nonexisting_key(self): + with self.assertRaises(google.api_core.exceptions.NotFound): + with smart_open.gcs.open(BUCKET_NAME, 'my_nonexisting_key', 'rb') as fin: + fin.read() + + def test_double_close(self): + text = u'там за туманами, вечными, пьяными'.encode('utf-8') + fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb') + fout.write(text) + fout.close() + fout.close() + + def test_flush_close(self): + text = u'там за туманами, вечными, пьяными'.encode('utf-8') + fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb') + fout.write(text) + fout.flush() + fout.close() + + def test_terminate(self): + text = u'там за туманами, вечными, пьяными'.encode('utf-8') + fout = smart_open.gcs.open(BUCKET_NAME, 'key', 'wb') + fout.write(text) + fout.terminate() + + with self.assertRaises(google.api_core.exceptions.NotFound): + with smart_open.gcs.open(BUCKET_NAME, 'key', 'rb') as fin: + fin.read() + + +@maybe_mock_gcs +class OpenTest(unittest.TestCase): + def setUp(self): + ignore_resource_warnings() + + def tearDown(self): + cleanup_bucket() + + def test_read_never_returns_none(self): + """read should never return None.""" + test_string = u"ветер по морю гуляет..." 
+ with smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, "wb") as fout: + fout.write(test_string.encode('utf8')) + + r = smart_open.gcs.open(BUCKET_NAME, BLOB_NAME, "rb") + self.assertEqual(r.read(), test_string.encode("utf-8")) + self.assertEqual(r.read(), b"") + self.assertEqual(r.read(), b"") + + +class MakeRangeStringTest(unittest.TestCase): + def test_no_stop(self): + start, stop = 1, None + self.assertEqual(smart_open.gcs._make_range_string(start, stop), 'bytes 1-/*') + + def test_stop(self): + start, stop = 1, 2 + self.assertEqual(smart_open.gcs._make_range_string(start, stop), 'bytes 1-2/*') + + +if __name__ == '__main__': + logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) + unittest.main() diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 2ea0a359..d16ac4e3 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -40,7 +40,7 @@ class ParseUriTest(unittest.TestCase): def test_scheme(self): """Do URIs schemes parse correctly?""" # supported schemes - for scheme in ("s3", "s3a", "s3n", "hdfs", "file", "http", "https"): + for scheme in ("s3", "s3a", "s3n", "hdfs", "file", "http", "https", "gs"): parsed_uri = smart_open_lib._parse_uri(scheme + "://mybucket/mykey") self.assertEqual(parsed_uri.scheme, scheme) @@ -273,6 +273,20 @@ def test_ssh_complex_password_with_colon(self): uri = smart_open_lib._parse_uri(as_string) self.assertEqual(uri.password, 'some:complex@password$$') + def test_gs_uri(self): + """Do GCS URIs parse correctly?""" + # correct uri without credentials + parsed_uri = smart_open_lib._parse_uri("gs://mybucket/myblob") + self.assertEqual(parsed_uri.scheme, "gs") + self.assertEqual(parsed_uri.bucket_id, "mybucket") + self.assertEqual(parsed_uri.blob_id, "myblob") + + def test_gs_uri_contains_slash(self): + parsed_uri = smart_open_lib._parse_uri("gs://mybucket/mydir/myblob") + self.assertEqual(parsed_uri.scheme, "gs") + self.assertEqual(parsed_uri.bucket_id, "mybucket") + self.assertEqual(parsed_uri.blob_id, "mydir/myblob") + class SmartOpenHttpTest(unittest.TestCase): """
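For reference, the ``Content-Range`` bookkeeping that the resumable-upload code and these tests rely on can be sketched in a few standalone lines; this mirrors ``smart_open.gcs._make_range_string`` and the finish check in ``FakeAuthorizedSession.put``, where a range declaring a concrete total (rather than ``*``) finalizes the upload::

    def make_range_string(start, stop=None, end='*'):
        # same formatting as smart_open.gcs._make_range_string
        if stop is None:
            return 'bytes %d-/%s' % (start, end)
        return 'bytes %d-%d/%s' % (start, stop, end)

    assert make_range_string(1) == 'bytes 1-/*'
    assert make_range_string(1, 2) == 'bytes 1-2/*'

    # Uploading b'test' (bytes 0-3) in a single chunk:
    assert make_range_string(0, 3) == 'bytes 0-3/*'     # total size still unknown
    assert make_range_string(0, 3, 4) == 'bytes 0-3/4'  # declares the total, finishing the upload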