Quick fix for dir_path and getsize for Azure Blob #1

Closed
wants to merge 13 commits
303 changes: 303 additions & 0 deletions zarr/storage.py
@@ -1886,3 +1886,306 @@ def __delitem__(self, key):
with self._mutex:
self._invalidate_keys()
self._invalidate_value(key)


# utility functions for object stores


def _strip_prefix_from_path(path, prefix):
# normalized things will not have any leading or trailing slashes
path_norm = normalize_storage_path(path)
prefix_norm = normalize_storage_path(prefix)
# guard against an empty prefix, which would otherwise strip the first character
if prefix_norm and path_norm.startswith(prefix_norm):
return path_norm[(len(prefix_norm)+1):]
else:
return path


def _append_path_to_prefix(path, prefix):
parts = [normalize_storage_path(prefix), normalize_storage_path(path)]
# drop empty components so an empty prefix does not produce a leading slash
return '/'.join(p for p in parts if p)


def atexit_rmgcspath(bucket, path):
from google.cloud import storage
client = storage.Client()
bucket = client.get_bucket(bucket)
bucket.delete_blobs(bucket.list_blobs(prefix=path))


class GCSStore(MutableMapping):
"""Storage class using a Google Cloud Storage (GCS)

Parameters
----------
bucket_name : string
The name of the GCS bucket
prefix : string, optional
The prefix within the bucket (i.e. subdirectory)
client_kwargs : dict, optional
Extra options passed to ``google.cloud.storage.Client`` when connecting
to GCS

Notes
-----
In order to use this store, you must install the Google Cloud Storage
`Python Client Library <https://cloud.google.com/storage/docs/reference/libraries>`_.
You must also provide valid application credentials, either by setting the
``GOOGLE_APPLICATION_CREDENTIALS`` environment variable or via
`default credentials <https://cloud.google.com/sdk/gcloud/reference/auth/application-default/login>`_.
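
Examples
--------
A minimal sketch of store-level usage (the bucket name below is a
placeholder; it assumes a writable bucket already exists and credentials
are configured as described above):

>>> store = GCSStore('my-gcs-bucket', prefix='example')  # doctest: +SKIP
>>> store['greeting'] = b'hello'  # doctest: +SKIP
>>> store['greeting']  # doctest: +SKIP
b'hello'
>>> store.listdir()  # doctest: +SKIP
['greeting']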
"""

def __init__(self, bucket_name, prefix=None, client_kwargs=None):

self.bucket_name = bucket_name
self.prefix = normalize_storage_path(prefix)
# avoid a shared mutable default for client_kwargs
self.client_kwargs = client_kwargs or {}
self.initialize_bucket()

def initialize_bucket(self):
from google.cloud import storage
# run `gcloud auth application-default login` from shell
client = storage.Client(**self.client_kwargs)
self.bucket = client.get_bucket(self.bucket_name)
# need to properly handle exceptions
import google.api_core.exceptions as exceptions
self.exceptions = exceptions

# needed for pickling
def __getstate__(self):
state = self.__dict__.copy()
del state['bucket']
del state['exceptions']
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.initialize_bucket()

def __enter__(self):
return self

def __exit__(self, *args):
pass

def full_path(self, path=None):
return _append_path_to_prefix(path, self.prefix)

def list_gcs_directory_blobs(self, path):
"""Return list of all blobs *directly* under a gcs prefix."""
prefix = normalize_storage_path(path) + '/'
return [blob.name for blob in
self.bucket.list_blobs(prefix=prefix, delimiter='/')]

# from https://github.com/GoogleCloudPlatform/google-cloud-python/issues/920
def list_gcs_subdirectories(self, path):
"""Return set of all "subdirectories" from a gcs prefix."""
prefix = normalize_storage_path(path) + '/'
iterator = self.bucket.list_blobs(prefix=prefix, delimiter='/')
prefixes = set()
for page in iterator.pages:
prefixes.update(page.prefixes)
# need to strip trailing slash to be consistent with os.listdir
return [path[:-1] for path in prefixes]

def list_gcs_directory(self, prefix, strip_prefix=True):
"""Return a list of all blobs and subdirectories from a gcs prefix."""
items = set()
items.update(self.list_gcs_directory_blobs(prefix))
items.update(self.list_gcs_subdirectories(prefix))
items = list(items)
if strip_prefix:
items = [_strip_prefix_from_path(path, prefix) for path in items]
return items

def listdir(self, path=None):
dir_path = self.full_path(path)
return sorted(self.list_gcs_directory(dir_path, strip_prefix=True))

def rmdir(self, path=None):
# make sure it's a directory
dir_path = normalize_storage_path(self.full_path(path)) + '/'
self.bucket.delete_blobs(self.bucket.list_blobs(prefix=dir_path))

def getsize(self, path=None):
# this function should *not* be recursive
# a lot of slash trickery is required to make this work right
full_path = self.full_path(path)
blob = self.bucket.get_blob(full_path)
if blob is not None:
return blob.size
else:
dir_path = normalize_storage_path(full_path) + '/'
blobs = self.bucket.list_blobs(prefix=dir_path, delimiter='/')
size = 0
for blob in blobs:
size += blob.size
return size

def clear(self):
self.rmdir()

def __getitem__(self, key):
blob_name = self.full_path(key)
blob = self.bucket.get_blob(blob_name)
if blob:
return blob.download_as_string()
else:
raise KeyError('Blob %s not found' % blob_name)

def __setitem__(self, key, value):
blob_name = self.full_path(key)
blob = self.bucket.blob(blob_name)
blob.upload_from_string(value)

def __delitem__(self, key):
blob_name = self.full_path(key)
try:
self.bucket.delete_blob(blob_name)
except self.exceptions.NotFound as er:
raise KeyError(er.message)

def __contains__(self, key):
blob_name = self.full_path(key)
return self.bucket.get_blob(blob_name) is not None

def __eq__(self, other):
return (
isinstance(other, GCSStore) and
self.bucket_name == other.bucket_name and
self.prefix == other.prefix
)

def __iter__(self):
blobs = self.bucket.list_blobs(prefix=self.prefix)
for blob in blobs:
yield _strip_prefix_from_path(blob.name, self.prefix)

def __len__(self):
iterator = self.bucket.list_blobs(prefix=self.prefix)
return len(list(iterator))

class ABSStore(MutableMapping):
"""Storage class using Azure Blob Storage (ABS).

Parameters
----------
container_name : string
The name of the ABS container to use
prefix : string
The prefix within the container (i.e. subdirectory)
account_name : string
The Azure Storage account name
account_key : string
The Azure Storage account access key

Notes
-----
In order to use this store, you must install the Microsoft Azure Storage SDK
for Python (``azure-storage-blob``), which provides ``BlockBlobService``.
"""

def __init__(self, container_name, prefix, account_name, account_key):
self.account_name = account_name
self.account_key = account_key
self.container_name = container_name
self.prefix = normalize_storage_path(prefix)
self.initialize_container()

def initialize_container(self):
from azure.storage.blob import BlockBlobService
self.client = BlockBlobService(self.account_name, self.account_key)
# change logging level to deal with https://github.com/Azure/azure-storage-python/issues/437
# it would be better to set up a logging filter that throws out just the
# error logged when calling exists().
import logging
logging.basicConfig(level=logging.CRITICAL)
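# A sketch of that more targeted alternative (untested; the logger name
# 'azure.storage' and the message filter are assumptions -- adjust them to
# whatever the SDK actually logs when exists() misses):
#
#     class _IgnoreBlobNotFound(logging.Filter):
#         def filter(self, record):
#             return 'BlobNotFound' not in record.getMessage()
#
#     logging.getLogger('azure.storage').addFilter(_IgnoreBlobNotFound())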

# needed for pickling
def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.initialize_container()

def __enter__(self):
return self

def __exit__(self, *args):
pass

def full_path(self, path=None):
return _append_path_to_prefix(path, self.prefix)

def __getitem__(self, key):
from azure.common import AzureMissingResourceHttpError
blob_name = '/'.join([self.prefix, key])
try:
blob = self.client.get_blob_to_bytes(self.container_name, blob_name)
except AzureMissingResourceHttpError:
# get_blob_to_bytes raises for a missing blob rather than returning None
raise KeyError('Blob %s not found' % blob_name)
return blob.content

def __setitem__(self, key, value):
blob_name = '/'.join([self.prefix, key])
# zarr passes raw bytes, so upload as bytes rather than text
self.client.create_blob_from_bytes(self.container_name, blob_name, value)

def __delitem__(self, key):
raise NotImplementedError

def __eq__(self, other):
return (
isinstance(other, ABSStore) and
self.container_name == other.container_name and
self.prefix == other.prefix
)

def keys(self):
raise NotImplementedError

def __iter__(self):
raise NotImplementedError

def __len__(self):
raise NotImplementedError

def __contains__(self, key):
# this is where the logging error occurs. not sure why we are looking for a .zarray below every blob
blob_name = '/'.join([self.prefix, key])
return self.client.exists(self.container_name, blob_name)

def list_abs_directory_blobs(self, prefix):
"""Return list of all blobs under an abs prefix."""
prefix = normalize_storage_path(prefix) + '/'
# note: unlike the GCS version this lists recursively (no delimiter used here)
return [blob.name for blob in
self.client.list_blobs(self.container_name, prefix=prefix)]

def list_abs_subdirectories(self, prefix):
"""Return list of all "subdirectories" from an abs prefix."""
prefix = normalize_storage_path(prefix) + '/'
return list(set([blob.name.rsplit('/', 1)[0] for blob in
self.client.list_blobs(self.container_name, prefix=prefix)
if '/' in blob.name]))

def list_abs_directory(self, prefix, strip_prefix=True):
"""Return a list of all blobs and subdirectories from an abs prefix."""
items = set()
items.update(self.list_abs_directory_blobs(prefix))
items.update(self.list_abs_subdirectories(prefix))
items = list(items)
if strip_prefix:
items = [_strip_prefix_from_path(path, prefix) for path in items]
return items

def dir_path(self, path=None):
store_path = normalize_storage_path(path)
# prefix is normalized to not have a trailing slash
dir_path = self.prefix
if store_path:
# join with '/' rather than os.path.join so blob keys stay platform independent
dir_path = dir_path + '/' + store_path
else:
dir_path += '/'
return dir_path

def listdir(self, path=None):
dir_path = self.dir_path(path)
return sorted(self.list_abs_directory(dir_path, strip_prefix=True))

def rename(self, src_path, dst_path):
raise NotImplementedError

def rmdir(self, path=None):
dir_path = normalize_storage_path(self.full_path(path)) + '/'
for blob in self.client.list_blobs(self.container_name, prefix=dir_path):
self.client.delete_blob(self.container_name, blob.name)

def getsize(self, path=None):
dir_path = self.dir_path(path)
size = 0
for blob in self.client.list_blobs(self.container_name, prefix=dir_path):
size += blob.properties.content_length # from https://stackoverflow.com/questions/47694592/get-container-sizes-in-azure-blob-storage-using-python
return size

def clear(self):
raise NotImplementedError
26 changes: 24 additions & 2 deletions zarr/tests/test_core.py
@@ -7,7 +7,7 @@
import pickle
import os
import warnings

import uuid

import numpy as np
from numpy.testing import assert_array_equal, assert_array_almost_equal
@@ -16,7 +16,7 @@

from zarr.storage import (DirectoryStore, init_array, init_group, NestedDirectoryStore,
DBMStore, LMDBStore, atexit_rmtree, atexit_rmglob,
LRUStoreCache)
LRUStoreCache, GCSStore, atexit_rmgcspath)
from zarr.core import Array
from zarr.errors import PermissionError
from zarr.compat import PY2, text_type, binary_type
@@ -1698,3 +1698,25 @@ def create_array(read_only=False, **kwargs):
init_array(store, **kwargs)
return Array(store, read_only=read_only, cache_metadata=cache_metadata,
cache_attrs=cache_attrs)


try:
from google.cloud import storage as gcstorage
except ImportError: # pragma: no cover
gcstorage = None


@unittest.skipIf(gcstorage is None, 'google-cloud-storage is not installed')
class TestGCSArray(TestArray):

def create_array(self, read_only=False, **kwargs):
bucket = 'zarr-test'
# use a string prefix; a UUID object would not be accepted downstream
prefix = str(uuid.uuid4())
atexit.register(atexit_rmgcspath, bucket, prefix)
store = GCSStore(bucket, prefix)
cache_metadata = kwargs.pop('cache_metadata', True)
cache_attrs = kwargs.pop('cache_attrs', True)
kwargs.setdefault('compressor', Zlib(1))
init_array(store, **kwargs)
return Array(store, read_only=read_only, cache_metadata=cache_metadata,
cache_attrs=cache_attrs)
30 changes: 29 additions & 1 deletion zarr/tests/test_storage.py
@@ -8,6 +8,7 @@
import array
import shutil
import os
import uuid


import numpy as np
@@ -19,7 +20,8 @@
DirectoryStore, ZipStore, init_group, group_meta_key,
getsize, migrate_1to2, TempStore, atexit_rmtree,
NestedDirectoryStore, default_compressor, DBMStore,
LMDBStore, atexit_rmglob, LRUStoreCache)
LMDBStore, atexit_rmglob, LRUStoreCache, GCSStore,
atexit_rmgcspath)
from zarr.meta import (decode_array_metadata, encode_array_metadata, ZARR_FORMAT,
decode_group_metadata, encode_group_metadata)
from zarr.compat import PY2
@@ -1235,3 +1237,29 @@ def test_format_compatibility():
else:
assert compressor.codec_id == z.compressor.codec_id
assert compressor.get_config() == z.compressor.get_config()


try:
from google.cloud import storage as gcstorage
except ImportError: # pragma: no cover
gcstorage = None


@unittest.skipIf(gcstorage is None, 'google-cloud-storage is not installed')
class TestGCSStore(StoreTests, unittest.TestCase):

def create_store(self):
# would need to be replaced with a dedicated test bucket
bucket = 'zarr-test'
# use a string prefix; a UUID object would not be accepted downstream
prefix = str(uuid.uuid4())
atexit.register(atexit_rmgcspath, bucket, prefix)
store = GCSStore(bucket, prefix)
return store

def test_context_manager(self):
with self.create_store() as store:
store['foo'] = b'bar'
store['baz'] = b'qux'
assert 2 == len(store)