Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zarr-python v3 compatibility #516

Draft
wants to merge 39 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
39722e7
Save progress for next week
mpiannucci Oct 4, 2024
d3c7e37
Bump zarr python version
mpiannucci Oct 5, 2024
25d7d14
Get some tests working others failing
mpiannucci Oct 5, 2024
ffe5f9d
get through single hdf to zarr
mpiannucci Oct 8, 2024
5aef233
Save progress
mpiannucci Oct 8, 2024
b9323d2
Cleanup, almost working with hdf
mpiannucci Oct 9, 2024
0f17119
Closer...
mpiannucci Oct 9, 2024
5c8806b
Updating tests
mpiannucci Oct 9, 2024
80fedcd
reorganize
mpiannucci Oct 10, 2024
1f69a0b
Save progress
mpiannucci Oct 10, 2024
d556e52
Refactor to clean things up
mpiannucci Oct 10, 2024
b27e64c
Fix circular import
mpiannucci Oct 10, 2024
41d6e8e
Iterate
mpiannucci Oct 10, 2024
7ade1a6
Change zarr dep
mpiannucci Oct 10, 2024
492ddee
More conversion
mpiannucci Oct 10, 2024
6e5741c
Specify zarr version
mpiannucci Oct 15, 2024
c0316ac
Working remote hdf tests
mpiannucci Oct 23, 2024
59bd36c
Working grib impl
mpiannucci Oct 23, 2024
187ced2
Add back commented out code
mpiannucci Oct 23, 2024
690ed21
Make grib codec a compressor since its bytes to array
mpiannucci Oct 23, 2024
5019b15
Switch back
mpiannucci Oct 23, 2024
d96cf46
Add first pass at grib zarr 3 codec
mpiannucci Oct 26, 2024
cbcb720
Fix typing
mpiannucci Oct 29, 2024
b88655f
Fix some broken tests; use async filesystem wrapper
moradology Nov 6, 2024
73eaf33
Implement zarr3 compatibility for grib
moradology Nov 20, 2024
3757199
Use zarr3 stores directly; avoid use of internal fs
moradology Nov 21, 2024
9444ff8
Merge pull request #4 from moradology/fix/zarr3-grib-tests
mpiannucci Nov 26, 2024
d8848ce
Forward
mpiannucci Nov 26, 2024
1fa294e
More
mpiannucci Nov 26, 2024
543178d
Figure out async wrapper
mpiannucci Nov 26, 2024
96b56cd
Closer on hdf5
mpiannucci Nov 26, 2024
0808b05
netcdf but failing
mpiannucci Nov 26, 2024
aef006e
grib passing
mpiannucci Nov 26, 2024
d9bf0dd
Fix inline test
mpiannucci Nov 26, 2024
884fc68
More
mpiannucci Nov 26, 2024
1145f45
standardize compressor name
mpiannucci Nov 27, 2024
94ec479
Fix one more hdf test
mpiannucci Nov 27, 2024
a9693d1
Small tweaks
mpiannucci Nov 27, 2024
7e9112a
Hide fsspec import where necessary
mpiannucci Nov 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions kerchunk/codecs.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import ast
from dataclasses import dataclass
import io
from typing import Self, TYPE_CHECKING

import numcodecs
from numcodecs.abc import Codec
import numpy as np
import threading
import zlib
from zarr.core.array_spec import ArraySpec
from zarr.abc.codec import ArrayBytesCodec
from zarr.core.buffer import Buffer, NDArrayLike, NDBuffer
from zarr.core.common import JSON, parse_enum, parse_named_configuration
from zarr.registry import register_codec


class FillStringsCodec(Codec):
Expand Down Expand Up @@ -115,6 +122,78 @@ def decode(self, buf, out=None):
numcodecs.register_codec(GRIBCodec, "grib")


@dataclass(frozen=True)
class GRIBZarrCodec(ArrayBytesCodec):
    """Zarr v3 array-bytes codec that decodes a single GRIB message into an array.

    Decode-only: each chunk is expected to hold one complete GRIB message, from
    which a single key (``var``) is extracted via eccodes. Encoding is not
    supported and raises ``NotImplementedError``.
    """

    # eccodes is not thread-safe/re-entrant; all codes_* calls are serialized
    # on this class-level lock.
    eclock = threading.RLock()

    var: str
    dtype: np.dtype

    def __init__(self, *, var: str, dtype: np.dtype) -> None:
        # frozen dataclass: attributes must be set via object.__setattr__
        object.__setattr__(self, "var", var)
        object.__setattr__(self, "dtype", dtype)

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        # BUG FIX: this codec is registered/serialized under the name "grib",
        # not "bytes" — the "bytes" literal was copy-paste residue from zarr's
        # BytesCodec and made to_dict() -> from_dict() round-tripping impossible.
        _, configuration_parsed = parse_named_configuration(
            data, "grib", require_configuration=True
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        # BUG FIX: the original branched on `self.endian`, an attribute this
        # class never defines (copied from zarr's BytesCodec), so calling
        # to_dict() raised AttributeError. var/dtype are always set, so always
        # emit the full configuration; str() makes the dtype JSON-serializable.
        return {
            "name": "grib",
            "configuration": {"var": self.var, "dtype": str(self.dtype)},
        }

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        assert isinstance(chunk_bytes, Buffer)
        import eccodes

        # lat/lon requests map to the plural eccodes keys; everything else is
        # read from the message's "values" array.
        if self.var in ["latitude", "longitude"]:
            var = self.var + "s"
            dt = self.dtype or "float64"
        else:
            var = "values"
            dt = self.dtype or "float32"

        with self.eclock:
            mid = eccodes.codes_new_from_message(chunk_bytes.to_bytes())
            try:
                data = eccodes.codes_get_array(mid, var)
                missingValue = eccodes.codes_get_string(mid, "missingValue")
                # Replace the sentinel missing value with NaN in data arrays.
                if var == "values" and missingValue:
                    data[data == float(missingValue)] = np.nan
                return data.astype(dt, copy=False)

            finally:
                # Always release the eccodes message handle, even on error.
                eccodes.codes_release(mid)

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        # This is a one-way (decode-only) codec.
        raise NotImplementedError

    def compute_encoded_size(
        self, input_byte_length: int, _chunk_spec: ArraySpec
    ) -> int:
        # Encoded size is unknowable without encoding support.
        raise NotImplementedError


register_codec("grib", GRIBZarrCodec)


class AsciiTableCodec(numcodecs.abc.Codec):
"""Decodes ASCII-TABLE extensions in FITS files"""

Expand Down Expand Up @@ -166,7 +245,6 @@ def decode(self, buf, out=None):
arr2 = np.empty((self.nrow,), dtype=dt_out)
heap = buf[arr.nbytes :]
for name in dt_out.names:

if dt_out[name] == "O":
dt = np.dtype(self.ftypes[self.types[name]])
counts = arr[name][:, 0]
Expand Down Expand Up @@ -244,8 +322,7 @@ def encode(self, buf):
class ZlibCodec(Codec):
codec_id = "zlib"

def __init__(self):
...
def __init__(self): ...

def decode(self, data, out=None):
if out:
Expand Down
63 changes: 39 additions & 24 deletions kerchunk/combine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import collections.abc
import logging
import re
Expand All @@ -10,8 +11,9 @@
import numcodecs
import ujson
import zarr
from zarr.core.buffer.core import default_buffer_prototype

from kerchunk.utils import consolidate
from kerchunk.utils import consolidate, fs_as_store, translate_refs_serializable

logger = logging.getLogger("kerchunk.combine")

Expand Down Expand Up @@ -199,11 +201,12 @@ def append(
remote_protocol=remote_protocol,
remote_options=remote_options,
target_options=target_options,
asynchronous=True
)
ds = xr.open_dataset(
fs.get_mapper(), engine="zarr", backend_kwargs={"consolidated": False}
)
z = zarr.open(fs.get_mapper())
z = zarr.open(fs.get_mapper(), zarr_format=2)
mzz = MultiZarrToZarr(
path,
out=fs.references, # dict or parquet/lazy
Expand Down Expand Up @@ -264,7 +267,7 @@ def fss(self):
self._paths = []
for of in fsspec.open_files(self.path, **self.target_options):
self._paths.append(of.full_name)
fs = fsspec.core.url_to_fs(self.path[0], **self.target_options)[0]
fs = fsspec.core.url_to_fs(self.path[0], asynchronous=True, **self.target_options)[0]
try:
# JSON path
fo_list = fs.cat(self.path)
Expand Down Expand Up @@ -348,6 +351,16 @@ def _get_value(self, index, z, var, fn=None):
logger.debug("Decode: %s -> %s", (selector, index, var, fn), o)
return o

async def _read_meta_files(self, m, files):
"""Helper to load multiple metadata files asynchronously"""
res = {}
for fn in files:
exists = await m.exists(fn)
if exists:
content = await m.get(fn, prototype=default_buffer_prototype())
res[fn] = ujson.dumps(ujson.loads(content.to_bytes()))
return res

def first_pass(self):
"""Accumulate the set of concat coords values across all inputs"""

Expand All @@ -360,7 +373,8 @@ def first_pass(self):
fs._dircache_from_items()

logger.debug("First pass: %s", i)
z = zarr.open_group(fs.get_mapper(""))
z_store = fs_as_store(fs, read_only=False)
z = zarr.open_group(z_store, zarr_format=2)
for var in self.concat_dims:
value = self._get_value(i, z, var, fn=self._paths[i])
if isinstance(value, np.ndarray):
Expand All @@ -386,16 +400,16 @@ def store_coords(self):
Write coordinate arrays into the output
"""
kv = {}
store = zarr.storage.KVStore(kv)
group = zarr.open(store)
m = self.fss[0].get_mapper("")
z = zarr.open(m)
store = zarr.storage.MemoryStore(kv)
group = zarr.open_group(store, zarr_format=2)
m = fs_as_store(self.fss[0], read_only=False)
z = zarr.open(m, zarr_format=2)
for k, v in self.coos.items():
if k == "var":
# The names of the variables to write in the second pass, not a coordinate
continue
# parametrize the threshold value below?
compression = numcodecs.Zstd() if len(v) > 100 else None
compressor = numcodecs.Zstd() if len(v) > 100 else None
kw = {}
if self.cf_units and k in self.cf_units:
if "M" not in self.coo_dtypes.get(k, ""):
Expand All @@ -420,11 +434,12 @@ def store_coords(self):
elif k in z:
# Fall back to existing fill value
kw["fill_value"] = z[k].fill_value
arr = group.create_dataset(
arr = group.create_array(
name=k,
data=data,
overwrite=True,
compressor=compression,
shape=data.shape,
exists_ok=True,
compressor=compressor,
dtype=self.coo_dtypes.get(k, data.dtype),
**kw,
)
Expand All @@ -441,10 +456,9 @@ def store_coords(self):
# TODO: rewrite .zarray/.zattrs with ujson to save space. Maybe make them by hand anyway.
self.out.update(kv)
logger.debug("Written coordinates")
for fn in [".zgroup", ".zattrs"]:
# top-level group attributes from first input
if fn in m:
self.out[fn] = ujson.dumps(ujson.loads(m[fn]))

metadata = asyncio.run(self._read_meta_files(m, [".zgroup", ".zattrs"]))
self.out.update(metadata)
logger.debug("Written global metadata")
self.done.add(2)

Expand All @@ -460,8 +474,8 @@ def second_pass(self):

for i, fs in enumerate(self.fss):
to_download = {}
m = fs.get_mapper("")
z = zarr.open(m)
m = fs_as_store(fs, read_only=False)
z = zarr.open(m, zarr_format=2)

if no_deps is None:
# done first time only
Expand Down Expand Up @@ -491,9 +505,8 @@ def second_pass(self):
if f"{v}/.zgroup" in fns:
# recurse into groups - copy meta, add to dirs to process and don't look
# for references in this dir
self.out[f"{v}/.zgroup"] = m[f"{v}/.zgroup"]
if f"{v}/.zattrs" in fns:
self.out[f"{v}/.zattrs"] = m[f"{v}/.zattrs"]
metadata = asyncio.run(self._read_meta_files(m, [f"{v}/.zgroup", f"{v}/.zattrs"]))
self.out.update(metadata)
dirs.extend([f for f in fns if not f.startswith(f"{v}/.z")])
continue
if v in self.identical_dims:
Expand All @@ -504,8 +517,9 @@ def second_pass(self):
self.out[k] = fs.references[k]
continue
logger.debug("Second pass: %s, %s", i, v)

zarray = ujson.loads(m[f"{v}/.zarray"])

zarray = asyncio.run(self._read_meta_files(m, [f"{v}/.zarray"]))[f"{v}/.zarray"]
zarray = ujson.loads(zarray)
if v not in chunk_sizes:
chunk_sizes[v] = zarray["chunks"]
elif chunk_sizes[v] != zarray["chunks"]:
Expand All @@ -516,7 +530,8 @@ def second_pass(self):
chunks so far: {zarray["chunks"]}"""
)
chunks = chunk_sizes[v]
zattrs = ujson.loads(m.get(f"{v}/.zattrs", "{}"))
zattr_meta = asyncio.run(self._read_meta_files(m, [f"{v}/.zattrs"]))
zattrs = ujson.loads(zattr_meta.get(f"{v}/.zattrs", {}))
coords = zattrs.get("_ARRAY_DIMENSIONS", [])
if zarray["shape"] and not coords:
coords = list("ikjlm")[: len(zarray["shape"])]
Expand Down
11 changes: 6 additions & 5 deletions kerchunk/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fsspec.implementations.reference import LazyReferenceMapper


from kerchunk.utils import class_factory
from kerchunk.utils import class_factory, dict_to_store
from kerchunk.codecs import AsciiTableCodec, VarArrCodec

try:
Expand Down Expand Up @@ -72,7 +72,8 @@ def process_file(

storage_options = storage_options or {}
out = out or {}
g = zarr.open(out)
store = dict_to_store(out)
g = zarr.open_group(store=store, zarr_format=2)

with fsspec.open(url, mode="rb", **storage_options) as f:
infile = fits.open(f, do_not_scale_image_data=True)
Expand Down Expand Up @@ -150,7 +151,7 @@ def process_file(
for name in dtype.names
if hdu.columns[name].format.startswith(("P", "Q"))
}
kwargs["object_codec"] = VarArrCodec(
kwargs["compressor"] = VarArrCodec(
str(dtype), str(dt2), nrows, types
)
dtype = dt2
Expand All @@ -164,7 +165,7 @@ def process_file(
# TODO: we could sub-chunk on biggest dimension
name = hdu.name or str(ext)
arr = g.empty(
name, dtype=dtype, shape=shape, chunks=shape, compression=None, **kwargs
name=name, dtype=dtype, shape=shape, chunks=shape, zarr_format=2, **kwargs
)
arr.attrs.update(
{
Expand Down Expand Up @@ -248,7 +249,7 @@ def add_wcs_coords(hdu, zarr_group=None, dataset=None, dtype="float32"):
}
if zarr_group is not None:
arr = zarr_group.empty(
name, shape=shape, chunks=shape, overwrite=True, dtype=dtype
name, shape=shape, chunks=shape, dtype=dtype, exists_ok=True
)
arr.attrs.update(attrs)
arr[:] = world_coord.value.reshape(shape)
Expand Down
Loading