Skip to content

Commit

Permalink
Merge pull request #5 from janelia-cellmap/s3_support
Browse files Browse the repository at this point in the history
support s3
  • Loading branch information
rhoadesScholar authored Apr 16, 2024
2 parents 84a9910 + ff4c78f commit af4b878
Showing 1 changed file with 34 additions and 7 deletions.
41 changes: 34 additions & 7 deletions funlib/persistence/arrays/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,41 @@
from funlib.geometry import Coordinate, Roi

import zarr
from zarr.n5 import N5FSStore
import h5py
import json
import logging
import os
import shutil
from typing import Optional, Union
from typing import Optional, Union, Sequence

logger = logging.getLogger(__name__)


def get_url(node: Union[zarr.Group, zarr.Array]) -> str:
    """
    Build a ``protocol://path`` url for the store backing a zarr node.

    Args:

        node:

            The zarr group or array whose ``store`` attribute is inspected.

    Returns:

        A url string of the form ``{protocol}://{store_path}``. The protocol
        is taken from ``store.fs.protocol`` when the store has an ``fs``
        attribute and is ``"file"`` otherwise. Any protocol prefix already
        present in ``store.path`` is removed before the url is assembled.

    Raises:

        ValueError: if the store has no ``path`` attribute.
    """
    store = node.store
    if not hasattr(store, "path"):
        raise ValueError(
            f"The store associated with this object has type {type(store)}, which "
            "cannot be resolved to a url"
        )

    if hasattr(store, "fs"):
        protocol = store.fs.protocol
        # ``str`` is itself a Sequence: indexing a plain string protocol such
        # as "memory" would keep only its first character, so only unpack
        # real collections of protocol names (tuple/list).
        if not isinstance(protocol, str) and isinstance(protocol, Sequence):
            protocol = protocol[0]
    else:
        protocol = "file"

    # fsstore keeps the protocol in the path, but not s3store
    if "://" in store.path:
        store_path = store.path.split("://")[-1]
    else:
        store_path = store.path
    return f"{protocol}://{store_path}"


def separate_store_path(store, path):
"""
sometimes you can pass a total os path to node, leading to
Expand Down Expand Up @@ -49,12 +74,14 @@ def access_parent(node):
zarr.hierarchy.Group : parent group that contains input group/array
"""

store_path, node_path = separate_store_path(node.store.path, node.path)
path = get_url(node)

store_path, node_path = separate_store_path(path, node.path)
if node_path == "":
raise RuntimeError(
f"{node.name} is in the root group of the {node.store.path} store."
)
raise RuntimeError(f"{node.name} is in the root group of the {path} store.")
else:
if store_path.endswith(".n5"):
store_path = N5FSStore(store_path)
return zarr.open(store=store_path, path=os.path.split(node_path)[0], mode="r")


Expand Down Expand Up @@ -212,7 +239,7 @@ def check_for_attrs_multiscale(ds, multiscale_group, multiscales):
if multiscales is not None:
logger.info("Found multiscales attributes")
scale = os.path.relpath(
separate_store_path(ds.store.path, ds.path)[1], multiscale_group.path
separate_store_path(get_url(ds), ds.path)[1], multiscale_group.path
)
if isinstance(ds.store, (zarr.n5.N5Store, zarr.n5.N5FSStore)):
for level in multiscales[0]["datasets"]:
Expand Down Expand Up @@ -399,7 +426,7 @@ def open_ds(filename: str, ds_name: str, mode: str = "r") -> Array:

elif filename.endswith(".n5"):
logger.debug("opening N5 dataset %s in %s", ds_name, filename)
ds = zarr.open(filename, mode=mode)[ds_name]
ds = zarr.open(N5FSStore(filename), mode=mode)[ds_name]

voxel_size, offset = _read_voxel_size_offset(ds, "F")
shape = Coordinate(ds.shape[-len(voxel_size) :])
Expand Down

0 comments on commit af4b878

Please sign in to comment.