diff --git a/setup.py b/setup.py
index 561d8ae9..64cbc9f4 100644
--- a/setup.py
+++ b/setup.py
@@ -57,13 +57,14 @@ def read(fname):
 
 aws_deps = ['boto3']
 gcp_deps = ['google-cloud-storage']
+asb_deps = ['azure-storage-blob', 'azure-common', 'azure-core']
 
-all_deps = install_requires + aws_deps + gcp_deps
+all_deps = install_requires + aws_deps + gcp_deps + asb_deps
 
 setup(
     name='smart_open',
     version=__version__,
-    description='Utils for streaming large files (S3, HDFS, GCS, gzip, bz2...)',
+    description='Utils for streaming large files (S3, HDFS, GCS, ASB, gzip, bz2...)',
     long_description=read('README.rst'),
 
     packages=find_packages(),
@@ -79,7 +80,7 @@ def read(fname):
     url='https://github.com/piskvorky/smart_open',
     download_url='http://pypi.python.org/pypi/smart_open',
 
-    keywords='file streaming, s3, hdfs, gcs',
+    keywords='file streaming, s3, hdfs, gcs, asb',
 
     license='MIT',
     platforms='any',
@@ -93,6 +94,7 @@ def read(fname):
         'test': tests_require,
         'aws': aws_deps,
         'gcp': gcp_deps,
+        'asb': asb_deps,
         'all': all_deps,
     },
 
diff --git a/smart_open/asb.py b/smart_open/asb.py
new file mode 100644
index 00000000..2a31c24a
--- /dev/null
+++ b/smart_open/asb.py
@@ -0,0 +1,461 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek
+# Copyright (C) 2020 Nicolas Mitchell
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+"""Implements file-like objects for reading and writing to/from Azure Storage Blob (ASB)."""
+
+import base64
+import io
+import logging
+
+import smart_open.bytebuffer
+import smart_open.constants
+import smart_open.utils
+
+import azure.storage.blob
+import azure.core.exceptions
+
+logger = logging.getLogger(__name__)
+
+_BINARY_TYPES = (bytes, bytearray, memoryview)
+"""Allowed binary buffer types for writing to the underlying Azure Storage Blob stream"""
+
+SCHEME = "asb"
+"""Supported scheme for Azure Storage Blob in smart_open endpoint URL"""
+
+_MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 4 * 1024**2
+"""Azure requires you to upload in multiples of 4MB, except for the last part."""
+
+_DEFAULT_MIN_PART_SIZE = 64 * 1024**2
+"""Default minimum part size for Azure Storage Blob multipart uploads is 64MB"""
+
+DEFAULT_BUFFER_SIZE = 4 * 1024**2
+"""Default buffer size for working with Azure Storage Blob is 4MB
+https://docs.microsoft.com/en-us/rest/api/storageservices/understanding-block-blobs--append-blobs--and-page-blobs
+"""
+
+
+def parse_uri(uri_as_string):
+    sr = smart_open.utils.safe_urlsplit(uri_as_string)
+    assert sr.scheme == SCHEME
+    container_id = sr.netloc
+    blob_id = sr.path.lstrip('/')
+    return dict(scheme=SCHEME, container_id=container_id, blob_id=blob_id)
+
+
+def open_uri(uri, mode, transport_params):
+    parsed_uri = parse_uri(uri)
+    kwargs = smart_open.utils.check_kwargs(open, transport_params)
+    return open(parsed_uri['container_id'], parsed_uri['blob_id'], mode, **kwargs)
+
+
+def open(
+        container_id,
+        blob_id,
+        mode,
+        buffer_size=DEFAULT_BUFFER_SIZE,
+        client=None,  # type: azure.storage.blob.BlobServiceClient
+        ):
+    """Open an Azure Storage Blob blob for reading or writing.
+
+    Parameters
+    ----------
+    container_id: str
+        The name of the container this object resides in.
+    blob_id: str
+        The name of the blob within the container.
+    mode: str
+        The mode for opening the object.  Must be either "rb" or "wb".
+    buffer_size: int, optional
+        The buffer size to use when performing I/O.  For reading only.
+    client: azure.storage.blob.BlobServiceClient, optional
+        The Azure Storage Blob client to use when working with azure-storage-blob.
+
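+    Example
+    -------
+    A minimal sketch; ``connect_str`` below is a placeholder for a real Azure
+    connection string and is not part of this module::
+
+        >>> client = azure.storage.blob.BlobServiceClient.from_connection_string(connect_str)
+        >>> with open('my-container', 'my-blob.txt', 'rb', client=client) as fin:
+        ...     data = fin.read()
+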
+    """
+    if mode == smart_open.constants.READ_BINARY:
+        return Reader(
+            container_id,
+            blob_id,
+            buffer_size=buffer_size,
+            line_terminator=smart_open.constants.BINARY_NEWLINE,
+            client=client,
+        )
+    elif mode == smart_open.constants.WRITE_BINARY:
+        return Writer(
+            container_id,
+            blob_id,
+            client=client,
+        )
+    else:
+        raise NotImplementedError('Azure Storage Blob support for mode %r not implemented' % mode)
+
+
+class _RawReader(object):
+    """Read an Azure Storage Blob file."""
+
+    def __init__(self, asb_blob, size):
+        # type: (azure.storage.blob.BlobClient, int) -> None
+        self._blob = asb_blob
+        self._size = size
+        self._position = 0
+
+    def seek(self, position):
+        """Seek to the specified position (byte offset) in the Azure Storage Blob blob.
+
+        :param int position: The byte offset from the beginning of the blob.
+
+        Returns the position after seeking.
+        """
+        self._position = position
+        return self._position
+
+    def read(self, size=-1):
+        if self._position >= self._size:
+            return b''
+        binary = self._download_blob_chunk(size)
+        self._position += len(binary)
+        return binary
+
+    def _download_blob_chunk(self, size):
+        if self._size == self._position:
+            #
+            # When reading, we can't seek to the first byte of an empty file.
+            # Similarly, we can't seek past the last byte.  Do nothing here.
+            #
+            return b''
+        elif size == -1:
+            stream = self._blob.download_blob(offset=self._position)
+        else:
+            stream = self._blob.download_blob(offset=self._position, length=size)
+        if isinstance(stream, azure.storage.blob.StorageStreamDownloader):
+            binary = stream.readall()
+        else:
+            binary = stream.read()
+        return binary
+
+
+class Reader(io.BufferedIOBase):
+    """Reads bytes from Azure Blob Storage.
+
+    Implements the io.BufferedIOBase interface of the standard library.
+
+    :raises azure.core.exceptions.ResourceNotFoundError: Raised when the blob to read from does not exist.
+
+    """
+    def __init__(
+            self,
+            container,
+            blob,
+            buffer_size=DEFAULT_BUFFER_SIZE,
+            line_terminator=smart_open.constants.BINARY_NEWLINE,
+            client=None,  # type: azure.storage.blob.BlobServiceClient
+    ):
+        if client is None:
+            client = azure.storage.blob.BlobServiceClient()
+        self._container_client = client.get_container_client(container)
+        # type: azure.storage.blob.ContainerClient
+
+        self._blob = self._container_client.get_blob_client(blob)
+        if self._blob is None:
+            raise azure.core.exceptions.ResourceNotFoundError(
+                'blob %s not found in %s' % (blob, container)
+            )
+        try:
+            self._size = self._blob.get_blob_properties()['size']
+        except KeyError:
+            self._size = 0
+
+        self._raw_reader = _RawReader(self._blob, self._size)
+        self._position = 0
+        self._current_part = smart_open.bytebuffer.ByteBuffer(buffer_size)
+        self._line_terminator = line_terminator
+
+        #
+        # This member is part of the io.BufferedIOBase interface.
+        #
+        self.raw = None
+
+    #
+    # Override some methods from io.IOBase.
+    #
+    def close(self):
+        """Flush and close this stream."""
+        logger.debug("close: called")
+        self._blob = None
+        self._raw_reader = None
+
+    def readable(self):
+        """Return True if the stream can be read from."""
+        return True
+
+    def seekable(self):
+        """If False, seek(), tell() and truncate() will raise IOError.
+
+        We offer only seek support, and no truncate support."""
+        return True
+
+    #
+    # io.BufferedIOBase methods.
+    #
+    def detach(self):
+        """Unsupported."""
+        raise io.UnsupportedOperation
+
+    def seek(self, offset, whence=smart_open.constants.WHENCE_START):
+        """Seek to the specified position.
+
+        :param int offset: The offset in bytes.
+        :param int whence: Where the offset is from.
+
+        Returns the position after seeking."""
+        logger.debug('seeking to offset: %r whence: %r', offset, whence)
+        if whence not in smart_open.constants.WHENCE_CHOICES:
+            raise ValueError('invalid whence %r, expected one of %r' % (whence,
+                             smart_open.constants.WHENCE_CHOICES))
+
+        if whence == smart_open.constants.WHENCE_START:
+            new_position = offset
+        elif whence == smart_open.constants.WHENCE_CURRENT:
+            new_position = self._position + offset
+        else:
+            new_position = self._size + offset
+        self._position = new_position
+        self._raw_reader.seek(new_position)
+        logger.debug('current_pos: %r', self._position)
+
+        self._current_part.empty()
+        return self._position
+
+    def tell(self):
+        """Return the current position within the file."""
+        return self._position
+
+    def truncate(self, size=None):
+        """Unsupported."""
+        raise io.UnsupportedOperation
+
+    def read(self, size=-1):
+        """Read up to size bytes from the object and return them."""
+        if size == 0:
+            return b''
+        elif size < 0:
+            self._position = self._size
+            return self._read_from_buffer() + self._raw_reader.read()
+
+        #
+        # Return unused data first
+        #
+        if len(self._current_part) >= size:
+            return self._read_from_buffer(size)
+
+        if self._position == self._size:
+            return self._read_from_buffer()
+
+        self._fill_buffer()
+        return self._read_from_buffer(size)
+
+    def read1(self, size=-1):
+        """This is the same as read()."""
+        return self.read(size=size)
+
+    def readinto(self, b):
+        """Read up to len(b) bytes into b, and return the number of bytes read."""
+        data = self.read(len(b))
+        if not data:
+            return 0
+        b[:len(data)] = data
+        return len(data)
+
+    def readline(self, limit=-1):
+        """Read up to and including the next newline.  Returns the bytes read."""
+        if limit != -1:
+            raise NotImplementedError('limits other than -1 not implemented yet')
+        the_line = io.BytesIO()
+        while not (self._position == self._size and len(self._current_part) == 0):
+            #
+            # In the worst case, we're reading the unread part of self._current_part
+            # twice here, once in the if condition and once when calling index.
+            #
+            # This is sub-optimal, but better than the alternative: wrapping
+            # .index in a try..except, because that is slower.
+            #
+            remaining_buffer = self._current_part.peek()
+            if self._line_terminator in remaining_buffer:
+                next_newline = remaining_buffer.index(self._line_terminator)
+                the_line.write(self._read_from_buffer(next_newline + 1))
+                break
+            else:
+                the_line.write(self._read_from_buffer())
+                self._fill_buffer()
+        return the_line.getvalue()
+
+    #
+    # Internal methods.
+    #
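+    # The reader keeps a single ByteBuffer (self._current_part) of downloaded
+    # but unread data; read() and readline() drain it first, and
+    # _fill_buffer() tops it up from the raw reader one buffer_size chunk at
+    # a time.
+    #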
+    def _read_from_buffer(self, size=-1):
+        """Remove at most size bytes from our buffer and return them."""
+        # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part))
+        size = size if size >= 0 else len(self._current_part)
+        part = self._current_part.read(size)
+        self._position += len(part)
+        # logger.debug('part: %r', part)
+        return part
+
+    def _fill_buffer(self, size=-1):
+        size = max(size, self._current_part._chunk_size)
+        while len(self._current_part) < size and not self._position == self._size:
+            bytes_read = self._current_part.fill(self._raw_reader)
+            if bytes_read == 0:
+                logger.debug('reached EOF while filling buffer')
+                return True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def __str__(self):
+        return "(%s, %r, %r)" % (self.__class__.__name__,
+                                 self._container_client.container_name,
+                                 self._blob.blob_name)
+
+    def __repr__(self):
+        return "%s(container=%r, blob=%r)" % (
+            self.__class__.__name__, self._container_client.container_name, self._blob.blob_name,
+        )
+
+
+class Writer(io.BufferedIOBase):
+    """Writes bytes to Azure Storage Blob.
+
+    Implements the io.BufferedIOBase interface of the standard library."""
+
+    def __init__(
+            self,
+            container,
+            blob,
+            min_part_size=_DEFAULT_MIN_PART_SIZE,
+            client=None,  # type: azure.storage.blob.BlobServiceClient
+    ):
+        if client is None:
+            client = azure.storage.blob.BlobServiceClient()
+        self._client = client
+        self._container_client = self._client.get_container_client(container)
+        # type: azure.storage.blob.ContainerClient
+        self._blob = self._container_client.get_blob_client(blob)  # type: azure.storage.blob.BlobClient
+        self._min_part_size = min_part_size
+
+        self._total_size = 0
+        self._total_parts = 0
+        self._bytes_uploaded = 0
+        self._current_part = io.BytesIO()
+        self._block_list = []
+
+        #
+        # This member is part of the io.BufferedIOBase interface.
+        #
+        self.raw = None
+
+    def flush(self):
+        pass
+
+    #
+    # Override some methods from io.IOBase.
+    #
+    def close(self):
+        logger.debug("closing")
+        if not self.closed:
+            self._client = None
+            logger.debug("successfully closed")
+
+    @property
+    def closed(self):
+        return self._client is None
+
+    def writable(self):
+        """Return True if the stream supports writing."""
+        return True
+
+    def tell(self):
+        """Return the current stream position."""
+        return self._total_size
+
+    #
+    # io.BufferedIOBase methods.
+    #
+    def detach(self):
+        raise io.UnsupportedOperation("detach() not supported")
+
+    def write(self, b):
+        """Write the given bytes (binary string) to the Azure Storage Blob file.
+
+        There's buffering happening under the covers, so this may not actually
+        do any HTTP transfer right away."""
+
+        if not isinstance(b, _BINARY_TYPES):
+            raise TypeError("input must be one of %r, got: %r" % (_BINARY_TYPES, type(b)))
+
+        self._current_part.write(b)
+        self._total_size += len(b)
+        if len(b) > 0:
+            self._upload_part()
+
+        return len(b)
+
+    def _upload_part(self):
+        part_num = self._total_parts + 1
+
+        #
+        # Here we upload the largest amount possible given Azure Storage Blob's
+        # restriction of parts being multiples of 4MB, except for the last one.
+        #
+        content_length = self._current_part.tell()
+        range_stop = self._bytes_uploaded + content_length - 1
+
+        #
+        # The block_id corresponds to the byte offset of the part within the
+        # blob, base64-encoded.
+        #
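+        # stage_block() uploads the part as an uncommitted block;
+        # commit_block_list() below then re-publishes the accumulated list of
+        # block ids, so the blob reflects all parts written so far.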
+        block_id = base64.b64encode(str(self._bytes_uploaded).encode())
+        self._current_part.seek(0)
+        self._blob.stage_block(block_id, self._current_part.read(content_length))
+        if block_id not in [block_blob['id'] for block_blob in self._block_list]:
+            self._block_list.append(azure.storage.blob.BlobBlock(block_id=block_id))
+
+        logger.info(
+            "uploading part #%i, %i bytes (total %.3fGB)",
+            part_num, content_length, range_stop / 1024.0 ** 3,
+        )
+
+        self._blob.commit_block_list(self._block_list)
+        self._total_parts += 1
+        self._bytes_uploaded += content_length
+
+        #
+        # For the last part, the below _current_part handling is a NOOP.
+        #
+        self._current_part = io.BytesIO(self._current_part.read())
+        self._current_part.seek(0, io.SEEK_END)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def __str__(self):
+        return "(%s, %r, %r)" % (
+            self.__class__.__name__,
+            self._container_client.container_name,
+            self._blob.blob_name
+        )
+
+    def __repr__(self):
+        return "%s(container=%r, blob=%r, min_part_size=%r)" % (
+            self.__class__.__name__,
+            self._container_client.container_name,
+            self._blob.blob_name,
+            self._min_part_size,
+        )
diff --git a/smart_open/tests/test_asb.py b/smart_open/tests/test_asb.py
new file mode 100644
index 00000000..ea60498d
--- /dev/null
+++ b/smart_open/tests/test_asb.py
@@ -0,0 +1,719 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek
+# Copyright (C) 2020 Nicolas Mitchell
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+import gzip
+import io
+import logging
+import os
+import time
+import uuid
+import unittest
+from collections import OrderedDict
+
+import smart_open
+import smart_open.constants
+
+import azure.storage.blob
+import azure.common
+import azure.core.exceptions
+
+CONTAINER_NAME = 'test-smartopen-{}'.format(uuid.uuid4().hex)
+BLOB_NAME = 'test-blob'
+DISABLE_MOCKS = os.environ.get('SO_DISABLE_ASB_MOCKS') == "1"
+
+logger = logging.getLogger(__name__)
+
+
+class FakeBlobClient(object):
+    # From Azure's BlobClient API
+    # https://azuresdkdocs.blob.core.windows.net/$web/python/azure-storage-blob/12.0.0/azure.storage.blob.html#azure.storage.blob.BlobClient
+    def __init__(self, container_client, name):
+        self._container_client = container_client  # type: FakeContainerClient
+        self.blob_name = name
+        self.metadata = dict(size=0)
+        self.__contents = io.BytesIO()
+        self._staged_contents = {}
+
+    def commit_block_list(self, block_list):
+        data = b''.join([self._staged_contents[block_blob['id']] for block_blob in block_list])
+        self.__contents = io.BytesIO(data)
+        self.set_blob_metadata(dict(size=len(data)))
+        self._container_client.register_blob_client(self)
+
+    def delete_blob(self):
+        self._container_client.delete_blob(self)
+
+    def download_blob(self, offset=None, length=None):
+        if offset is None:
+            return self.__contents
+        self.__contents.seek(offset)
+        return io.BytesIO(self.__contents.read(length))
+
+    def get_blob_properties(self):
+        return self.metadata
+
+    def set_blob_metadata(self, metadata):
+        self.metadata = metadata
+
+    def stage_block(self, block_id, data):
+        self._staged_contents[block_id] = data
+
+    def upload_blob(self, data, length=None, metadata=None):
+        if metadata is not None:
+            self.set_blob_metadata(metadata)
+        self.__contents = io.BytesIO(data[:length])
+        self.set_blob_metadata(dict(size=len(data[:length])))
+        self._container_client.register_blob_client(self)
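+
+
+# FakeBlobClient above and the Fake* classes below are minimal in-memory
+# stand-ins for the azure-storage-blob v12 client hierarchy
+# (BlobServiceClient -> ContainerClient -> BlobClient); they implement only
+# the methods the asb transport and these tests exercise.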
+
+
+class FakeBlobClientTest(unittest.TestCase):
+    def setUp(self):
+        self.blob_service_client = FakeBlobServiceClient()
+        self.container_client = FakeContainerClient(self.blob_service_client, 'test-container')
+        self.blob_client = FakeBlobClient(self.container_client, 'test-blob.txt')
+
+    def test_delete_blob(self):
+        data = b'Lorem ipsum'
+        self.blob_client.upload_blob(data)
+        self.assertEqual(self.container_client.list_blobs(), [self.blob_client.blob_name])
+        self.blob_client.delete_blob()
+        self.assertEqual(self.container_client.list_blobs(), [])
+
+    def test_upload_blob(self):
+        data = b'Lorem ipsum'
+        self.blob_client.upload_blob(data)
+        actual = self.blob_client.download_blob().read()
+        self.assertEqual(actual, data)
+
+
+class FakeContainerClient(object):
+    # From Azure's ContainerClient API
+    # https://docs.microsoft.com/fr-fr/python/api/azure-storage-blob/azure.storage.blob.containerclient?view=azure-python
+    def __init__(self, blob_service_client, name):
+        self.blob_service_client = blob_service_client  # type: FakeBlobServiceClient
+        self.container_name = name
+        self.metadata = {}
+        self.__blob_clients = OrderedDict()
+
+    def create_container(self, metadata):
+        self.metadata = metadata
+
+    def delete_blob(self, blob):
+        del self.__blob_clients[blob.blob_name]
+
+    def delete_blobs(self):
+        self.__blob_clients = OrderedDict()
+
+    def delete_container(self):
+        self.blob_service_client.delete_container(self.container_name)
+
+    def download_blob(self, blob):
+        if blob.blob_name not in list(self.__blob_clients.keys()):
+            raise azure.core.exceptions.ResourceNotFoundError('The specified blob does not exist.')
+        blob_client = self.__blob_clients[blob.blob_name]
+        blob_content = blob_client.download_blob()
+        return blob_content
+
+    def get_blob_client(self, blob_name):
+        return self.__blob_clients.get(blob_name, FakeBlobClient(self, blob_name))
+
+    def get_container_properties(self):
+        return self.metadata
+
+    def list_blobs(self):
+        return list(self.__blob_clients.keys())
+
+    def upload_blob(self, blob_name, data):
+        blob_client = FakeBlobClient(self, blob_name)
+        blob_client.upload_blob(data)
+        self.__blob_clients[blob_name] = blob_client
+
+    def register_blob_client(self, blob_client):
+        self.__blob_clients[blob_client.blob_name] = blob_client
+
+
+class FakeContainerClientTest(unittest.TestCase):
+    def setUp(self):
+        self.blob_service_client = FakeBlobServiceClient()
+        self.container_client = FakeContainerClient(self.blob_service_client, 'test-container')
+
+    def test_nonexistent_blob(self):
+        blob_client = self.container_client.get_blob_client('test-blob.txt')
+        with self.assertRaises(azure.core.exceptions.ResourceNotFoundError):
+            self.container_client.download_blob(blob_client)
+
+    def test_delete_blob(self):
+        blob_name = 'test-blob.txt'
+        data = b'Lorem ipsum'
+        self.container_client.upload_blob(blob_name, data)
+        self.assertEqual(self.container_client.list_blobs(), [blob_name])
+        blob_client = FakeBlobClient(self.container_client, 'test-blob.txt')
+        self.container_client.delete_blob(blob_client)
+        self.assertEqual(self.container_client.list_blobs(), [])
+
+    def test_delete_blobs(self):
+        blob_name_1 = 'test-blob-1.txt'
+        blob_name_2 = 'test-blob-2.txt'
+        data = b'Lorem ipsum'
+        self.container_client.upload_blob(blob_name_1, data)
+        self.container_client.upload_blob(blob_name_2, data)
+        self.assertEqual(self.container_client.list_blobs(), [blob_name_1, blob_name_2])
+        self.container_client.delete_blobs()
+        self.assertEqual(self.container_client.list_blobs(), [])
+
+    def test_delete_container(self):
+        container_name = 'test-container'
+        container_client = self.blob_service_client.create_container(container_name)
+        self.assertEqual(self.blob_service_client.get_container_client(container_name).container_name,
+                         container_name)
+        container_client.delete_container()
+        with self.assertRaises(azure.core.exceptions.ResourceNotFoundError):
+            self.blob_service_client.get_container_client(container_name)
+
+    def test_list_blobs(self):
+        blob_name_1 = 'test-blob-1.txt'
+        blob_name_2 = 'test-blob-2.txt'
+        data = b'Lorem ipsum'
+        self.container_client.upload_blob(blob_name_1, data)
+        self.container_client.upload_blob(blob_name_2, data)
+        self.assertEqual(self.container_client.list_blobs(), [blob_name_1, blob_name_2])
+        self.container_client.delete_blobs()
+        self.assertEqual(self.container_client.list_blobs(), [])
+
+    def test_upload_blob(self):
+        blob_name = 'test-blob.txt'
+        data = b'Lorem ipsum'
+        self.container_client.upload_blob(blob_name, data)
+        blob_client = self.container_client.get_blob_client(blob_name)
+        actual = self.container_client.download_blob(blob_client).read()
+        self.assertEqual(actual, data)
+
+
+class FakeBlobServiceClient(object):
+    # From Azure's BlobServiceClient API
+    # https://docs.microsoft.com/fr-fr/python/api/azure-storage-blob/azure.storage.blob.blobserviceclient?view=azure-python
+    def __init__(self):
+        self.__container_clients = OrderedDict()
+
+    def create_container(self, container_name, metadata=None):
+        if container_name in self.__container_clients:
+            raise azure.core.exceptions.ResourceExistsError('The specified container already exists.')
+        container_client = FakeContainerClient(self, container_name)
+        if metadata is not None:
+            container_client.create_container(metadata)
+        self.__container_clients[container_name] = container_client
+        return container_client
+
+    def delete_container(self, container_name):
+        del self.__container_clients[container_name]
+
+    def get_blob_client(self, container, blob):
+        container = self.__container_clients[container]
+        blob_client = container.get_blob_client(blob)
+        return blob_client
+
+    def get_container_client(self, container):
+        if container not in self.__container_clients:
+            raise azure.core.exceptions.ResourceNotFoundError('The specified container does not exist.')
+        return self.__container_clients[container]
+
+
+class FakeBlobServiceClientTest(unittest.TestCase):
+    def setUp(self):
+        self.blob_service_client = FakeBlobServiceClient()
+
+    def test_nonexistent_container(self):
+        with self.assertRaises(azure.core.exceptions.ResourceNotFoundError):
+            self.blob_service_client.get_container_client('test-container')
+
+    def test_create_container(self):
+        container_name = 'test_container'
+        expected = self.blob_service_client.create_container(container_name)
+        actual = self.blob_service_client.get_container_client(container_name)
+        self.assertEqual(actual, expected)
+
+    def test_duplicate_container(self):
+        container_name = 'test-container'
+        self.blob_service_client.create_container(container_name)
+        with self.assertRaises(azure.core.exceptions.ResourceExistsError):
+            self.blob_service_client.create_container(container_name)
+
+    def test_delete_container(self):
+        container_name = 'test_container'
+        self.blob_service_client.create_container(container_name)
+        self.blob_service_client.delete_container(container_name)
+        with self.assertRaises(azure.core.exceptions.ResourceNotFoundError):
+            self.blob_service_client.get_container_client(container_name)
+
+    def test_get_blob_client(self):
+        container_name = 'test_container'
+        blob_name = 'test-blob.txt'
+        self.blob_service_client.create_container(container_name)
+        blob_client = self.blob_service_client.get_blob_client(container_name, blob_name)
+        self.assertEqual(blob_client.blob_name, blob_name)
+
+
+if DISABLE_MOCKS:
+    """If mocks are disabled, use the Azurite local Azure Storage API emulator
+    https://github.com/Azure/Azurite
+    To use locally:
+    docker run -p 10000:10000 -p 10001:10001 mcr.microsoft.com/azure-storage/azurite
+    """
+    # use Azurite's default connection string
+    CONNECT_STR = 'DefaultEndpointsProtocol=http;' \
+                  'AccountName=devstoreaccount1;' \
+                  'AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/' \
+                  'K1SZFPTOtr/KBHBeksoGMGw==;' \
+                  'BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;'
+    test_blob_service_client = azure.storage.blob.BlobServiceClient.from_connection_string(CONNECT_STR)
+else:
+    test_blob_service_client = FakeBlobServiceClient()
+
+
+def get_container_client():
+    return test_blob_service_client.get_container_client(container=CONTAINER_NAME)
+
+
+def cleanup_container():
+    container_client = get_container_client()
+    container_client.delete_blobs()
+
+
+def put_to_container(blob_name, contents, num_attempts=12, sleep_time=5):
+    logger.debug('%r', locals())
+
+    #
+    # In real life, it can take a few seconds for the container to become ready.
+    # If we try to write to the key before the container is ready, we
+    # will get a StorageError: NotFound.
+    #
+    for attempt in range(num_attempts):
+        try:
+            container_client = get_container_client()
+            container_client.upload_blob(blob_name, contents)
+            return
+        except azure.common.AzureHttpError as err:
+            logger.error('caught %r, retrying', err)
+            time.sleep(sleep_time)
+
+    assert False, 'failed to write to container %s after %d attempts' % (CONTAINER_NAME, num_attempts)
+
+
+def setUpModule():  # noqa
+    """Called once by unittest when initializing this module.  Sets up the
+    test Azure container.
+    """
+    test_blob_service_client.create_container(CONTAINER_NAME)
+
+
+def tearDownModule():  # noqa
+    """Called once by unittest when tearing down this module.  Empties and
+    removes the test Azure container.
+ """ + try: + container_client = get_container_client() + container_client.delete_container() + except azure.common.AzureHttpError: + pass + + +class ReaderTest(unittest.TestCase): + + def tearDown(self): + cleanup_container() + + def test_iter(self): + """Are Azure Storage Blob files iterated over correctly?""" + expected = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_iter_%s" % BLOB_NAME + put_to_container(blob_name, contents=expected) + + # connect to fake Azure Storage Blob and read from the fake key we filled above + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + output = [line.rstrip(b'\n') for line in fin] + self.assertEqual(output, expected.split(b'\n')) + + def test_iter_context_manager(self): + # same thing but using a context manager + expected = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_iter_context_manager_%s" % BLOB_NAME + put_to_container(blob_name, contents=expected) + + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client + ) as fin: + output = [line.rstrip(b'\n') for line in fin] + self.assertEqual(output, expected.split(b'\n')) + + def test_read(self): + """Are Azure Storage Blob files read correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_read_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + logger.debug('content: %r len: %r', content, len(content)) + + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + self.assertEqual(content[:6], fin.read(6)) + self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes + self.assertEqual(content[14:], fin.read()) # read the rest + + def test_seek_beginning(self): + """Does seeking to the beginning of Azure Storage Blob files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_seek_beginning_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + self.assertEqual(content[:6], fin.read(6)) + self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes + + fin.seek(0) + self.assertEqual(content, fin.read()) # no size given => read whole file + + fin.seek(0) + self.assertEqual(content, fin.read(-1)) # same thing + + def test_seek_start(self): + """Does seeking from the start of Azure Storage Blob files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_seek_start_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + seek = fin.seek(6) + self.assertEqual(seek, 6) + self.assertEqual(fin.tell(), 6) + self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) + + def test_seek_current(self): + """Does seeking from the middle of Azure Storage Blob files work correctly?""" + content = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_seek_current_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + self.assertEqual(fin.read(5), b'hello') + seek = fin.seek(1, whence=smart_open.constants.WHENCE_CURRENT) + self.assertEqual(seek, 6) + self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) + + def test_seek_end(self): + """Does seeking from the end of Azure Storage Blob files work correctly?""" + content = u"hello wořld\nhow are 
you?".encode('utf8') + blob_name = "test_seek_end_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END) + self.assertEqual(seek, len(content) - 4) + self.assertEqual(fin.read(), b'you?') + + def test_detect_eof(self): + content = u"hello wořld\nhow are you?".encode('utf8') + blob_name = "test_detect_eof_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + fin = smart_open.asb.Reader(CONTAINER_NAME, blob_name, client=test_blob_service_client) + fin.read() + eof = fin.tell() + self.assertEqual(eof, len(content)) + fin.seek(0, whence=smart_open.constants.WHENCE_END) + self.assertEqual(eof, fin.tell()) + + def test_read_gzip(self): + expected = u'раcцветали яблони и груши, поплыли туманы над рекой...'.encode('utf-8') + buf = io.BytesIO() + buf.close = lambda: None # keep buffer open so that we can .getvalue() + with gzip.GzipFile(fileobj=buf, mode='w') as zipfile: + zipfile.write(expected) + blob_name = "test_read_gzip_%s" % BLOB_NAME + put_to_container(blob_name, contents=buf.getvalue()) + + # + # Make sure we're reading things correctly. + # + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client + ) as fin: + self.assertEqual(fin.read(), buf.getvalue()) + + # + # Make sure the buffer we wrote is legitimate gzip. + # + sanity_buf = io.BytesIO(buf.getvalue()) + with gzip.GzipFile(fileobj=sanity_buf) as zipfile: + self.assertEqual(zipfile.read(), expected) + + logger.debug('starting actual test') + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client + ) as fin: + with gzip.GzipFile(fileobj=fin) as zipfile: + actual = zipfile.read() + + self.assertEqual(expected, actual) + + def test_readline(self): + content = b'englishman\nin\nnew\nyork\n' + blob_name = "test_readline_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client + ) as fin: + fin.readline() + self.assertEqual(fin.tell(), content.index(b'\n')+1) + + fin.seek(0) + actual = list(fin) + self.assertEqual(fin.tell(), len(content)) + + expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] + self.assertEqual(expected, actual) + + def test_readline_tiny_buffer(self): + content = b'englishman\nin\nnew\nyork\n' + blob_name = "test_readline_tiny_buffer_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client, + buffer_size=8 + ) as fin: + actual = list(fin) + + expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] + self.assertEqual(expected, actual) + + def test_read0_does_not_return_data(self): + content = b'englishman\nin\nnew\nyork\n' + blob_name = "test_read0_does_not_return_data_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client + ) as fin: + data = fin.read(0) + + self.assertEqual(data, b'') + + def test_read_past_end(self): + content = b'englishman\nin\nnew\nyork\n' + blob_name = "test_read_past_end_%s" % BLOB_NAME + put_to_container(blob_name, contents=content) + + with smart_open.asb.Reader( + CONTAINER_NAME, + blob_name, + client=test_blob_service_client + ) as fin: + data = fin.read(100) + + self.assertEqual(data, content) + + +class 
+class WriterTest(unittest.TestCase):
+    """Test writing into asb files."""
+
+    def tearDown(self):
+        cleanup_container()
+
+    def test_write_01(self):
+        """Does writing into Azure Storage Blob work correctly?"""
+        test_string = u"žluťoučký koníček".encode('utf8')
+        blob_name = "test_write_01_%s" % BLOB_NAME
+
+        with smart_open.asb.Writer(
+            CONTAINER_NAME,
+            blob_name,
+            client=test_blob_service_client
+        ) as fout:
+            fout.write(test_string)
+
+        output = list(smart_open.open(
+            "asb://%s/%s" % (CONTAINER_NAME, blob_name),
+            "rb",
+            transport_params=dict(client=test_blob_service_client),
+        ))
+        self.assertEqual(output, [test_string])
+
+    def test_incorrect_input(self):
+        """Does asb write fail on incorrect input?"""
+        blob_name = "test_incorrect_input_%s" % BLOB_NAME
+        try:
+            with smart_open.asb.Writer(
+                CONTAINER_NAME,
+                blob_name,
+                client=test_blob_service_client
+            ) as fout:
+                fout.write(None)
+        except TypeError:
+            pass
+        else:
+            self.fail()
+
+    def test_write_02(self):
+        """Does Azure Storage Blob write unicode-utf8 conversion work?"""
+        blob_name = "test_write_02_%s" % BLOB_NAME
+        smart_open_write = smart_open.asb.Writer(CONTAINER_NAME, blob_name, client=test_blob_service_client)
+        smart_open_write.tell()
+        logger.info("smart_open_write: %r", smart_open_write)
+        with smart_open_write as fout:
+            fout.write(u"testžížáč".encode("utf-8"))
+            self.assertEqual(fout.tell(), 14)
+
+    def test_write_03(self):
+        """Do multiple writes work correctly?"""
+        # write
+        blob_name = "test_write_03_%s" % BLOB_NAME
+        smart_open_write = smart_open.asb.Writer(CONTAINER_NAME, blob_name, client=test_blob_service_client)
+        local_write = io.BytesIO()
+
+        with smart_open_write as fout:
+            first_part = b"t" * 64
+            fout.write(first_part)
+            local_write.write(first_part)
+            self.assertEqual(fout.tell(), 64)
+
+            second_part = b"t\n"
+            fout.write(second_part)
+            local_write.write(second_part)
+            self.assertEqual(fout.tell(), 66)
+            self.assertEqual(fout._total_parts, 2)
+
+            third_part = b"t"
+            fout.write(third_part)
+            local_write.write(third_part)
+            self.assertEqual(fout.tell(), 67)
+            self.assertEqual(fout._total_parts, 3)
+
+            fourth_part = b"t" * 1
+            fout.write(fourth_part)
+            local_write.write(fourth_part)
+            self.assertEqual(fout.tell(), 68)
+            self.assertEqual(fout._total_parts, 4)
+
+        # read back the same key and check its content
+        output = list(smart_open.open(
+            "asb://%s/%s" % (CONTAINER_NAME, blob_name),
+            transport_params=dict(client=test_blob_service_client),
+        ))
+        local_write.seek(0)
+        actual = [line.decode("utf-8") for line in list(local_write)]
+        self.assertEqual(output, actual)
+
+    def test_write_04(self):
+        """Does writing no data cause key with an empty value to be created?"""
+        blob_name = "test_write_04_%s" % BLOB_NAME
+        smart_open_write = smart_open.asb.Writer(CONTAINER_NAME, blob_name, client=test_blob_service_client)
+        with smart_open_write as fout:  # noqa
+            pass
+
+        # read back the same key and check its content
+        output = list(smart_open.open(
+            "asb://%s/%s" % (CONTAINER_NAME, blob_name),
+            transport_params=dict(client=test_blob_service_client),
+        ))
+        self.assertEqual(output, [])
+
+    def test_gzip(self):
+        expected = u'а не спеть ли мне песню... о любви'.encode('utf-8')
+        blob_name = "test_gzip_%s" % BLOB_NAME
+        with smart_open.asb.Writer(
+            CONTAINER_NAME,
+            blob_name,
+            client=test_blob_service_client
+        ) as fout:
+            with gzip.GzipFile(fileobj=fout, mode='w') as zipfile:
+                zipfile.write(expected)
+
+        with smart_open.asb.Reader(
+            CONTAINER_NAME,
+            blob_name,
+            client=test_blob_service_client
+        ) as fin:
+            with gzip.GzipFile(fileobj=fin) as zipfile:
+                actual = zipfile.read()
+
+        self.assertEqual(expected, actual)
+
+    def test_buffered_writer_wrapper_works(self):
+        """
+        Ensure that we can wrap a smart_open asb stream in a BufferedWriter, which
+        passes a memoryview object to the underlying stream in python >= 2.7
+        """
+        expected = u'не думай о секундах свысока'
+        blob_name = "test_buffered_writer_wrapper_works_%s" % BLOB_NAME
+
+        with smart_open.asb.Writer(
+            CONTAINER_NAME,
+            blob_name,
+            client=test_blob_service_client
+        ) as fout:
+            with io.BufferedWriter(fout) as sub_out:
+                sub_out.write(expected.encode('utf-8'))
+
+        with smart_open.open(
+            "asb://%s/%s" % (CONTAINER_NAME, blob_name),
+            'rb',
+            transport_params=dict(client=test_blob_service_client)
+        ) as fin:
+            with io.TextIOWrapper(fin, encoding='utf-8') as text:
+                actual = text.read()
+
+        self.assertEqual(expected, actual)
+
+    def test_binary_iterator(self):
+        expected = u"выйду ночью в поле с конём".encode('utf-8').split(b' ')
+        blob_name = "test_binary_iterator_%s" % BLOB_NAME
+        put_to_container(blob_name=blob_name, contents=b"\n".join(expected))
+        with smart_open.asb.open(
+            CONTAINER_NAME,
+            blob_name,
+            'rb',
+            client=test_blob_service_client
+        ) as fin:
+            actual = [line.rstrip() for line in fin]
+        self.assertEqual(expected, actual)
+
+    def test_nonexisting_container(self):
+        expected = u"выйду ночью в поле с конём".encode('utf-8')
+        with self.assertRaises(azure.core.exceptions.ResourceNotFoundError):
+            with smart_open.asb.open(
+                'thiscontainerdoesntexist',
+                'mykey',
+                'wb',
+                client=test_blob_service_client
+            ) as fout:
+                fout.write(expected)
+
+    def test_double_close(self):
+        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
+        fout = smart_open.asb.open(CONTAINER_NAME, 'key', 'wb', client=test_blob_service_client)
+        fout.write(text)
+        fout.close()
+        fout.close()
+
+    def test_flush_close(self):
+        text = u'там за туманами, вечными, пьяными'.encode('utf-8')
+        fout = smart_open.asb.open(CONTAINER_NAME, 'key', 'wb', client=test_blob_service_client)
+        fout.write(text)
+        fout.flush()
+        fout.close()
diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py
index be9d3858..6ba82b36 100644
--- a/smart_open/tests/test_smart_open.py
+++ b/smart_open/tests/test_smart_open.py
@@ -40,7 +40,7 @@ class ParseUriTest(unittest.TestCase):
     def test_scheme(self):
         """Do URIs schemes parse correctly?"""
         # supported schemes
-        for scheme in ("s3", "s3a", "s3n", "hdfs", "file", "http", "https", "gs"):
+        for scheme in ("s3", "s3a", "s3n", "hdfs", "file", "http", "https", "gs", "asb"):
             parsed_uri = smart_open_lib._parse_uri(scheme + "://mybucket/mykey")
             self.assertEqual(parsed_uri.scheme, scheme)
 
diff --git a/smart_open/transport.py b/smart_open/transport.py
index 2d43d862..b6bf9dbc 100644
--- a/smart_open/transport.py
+++ b/smart_open/transport.py
@@ -75,6 +75,7 @@ def get_transport(scheme):
 
 
 register_transport(smart_open.local_file)
+register_transport('smart_open.asb')
 register_transport('smart_open.gcs')
 register_transport('smart_open.hdfs')
 register_transport('smart_open.http')
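
With the transport registered, the new scheme works through smart_open's
top-level API.  A minimal usage sketch (the connection string, container and
blob names below are placeholders, not part of this change):

    import azure.storage.blob
    import smart_open

    client = azure.storage.blob.BlobServiceClient.from_connection_string(connect_str)
    with smart_open.open('asb://mycontainer/myblob.txt', 'wb',
                         transport_params=dict(client=client)) as fout:
        fout.write(b'hello world')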