Update static typing #213

Merged
4 changes: 4 additions & 0 deletions .github/workflows/main.yml
@@ -42,6 +42,10 @@ jobs:
conda env list
conda list

+      - name: Type check
+        run: |
+          mypy virtualizarr

- name: Running Tests
run: |
python -m pytest ./virtualizarr --run-network-tests --cov=./ --cov-report=xml --verbose
24 changes: 0 additions & 24 deletions .pre-commit-config.yaml
@@ -18,27 +18,3 @@ repos:
args: [ --fix ]
# Run the formatter.
- id: ruff-format

-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.10.0
-    hooks:
-      - id: mypy
-        # Copied from setup.cfg
-        exclude: "properties|asv_bench|docs"
-        additional_dependencies: [
-          # Type stubs
-          types-python-dateutil,
-          types-setuptools,
-          types-PyYAML,
-          types-pytz,
-          # Dependencies that are typed
-          numpy,
-          typing-extensions>=4.1.0,
-        ]
-  # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194
-  # - repo: https://github.com/asottile/pyupgrade
-  #   rev: v3.15.2
-  #   hooks:
-  #     - id: pyupgrade
-  #     args:
-  #       - "--py310-plus"
2 changes: 2 additions & 0 deletions ci/environment.yml
@@ -17,7 +17,9 @@ dependencies:
# Testing
- codecov
- pre-commit
+  - mypy
  - ruff
+  - pandas-stubs
- pytest-mypy
- pytest-cov
- pytest
34 changes: 23 additions & 11 deletions pyproject.toml
@@ -34,18 +34,20 @@ dependencies = [
[project.optional-dependencies]
test = [
    "codecov",
+    "fastparquet",
+    "fsspec",
+    "h5py",
+    "mypy",
+    "netcdf4",
+    "pandas-stubs",
+    "pooch",
    "pre-commit",
-    "ruff",
-    "pytest-mypy",
    "pytest-cov",
+    "pytest-mypy",
    "pytest",
-    "pooch",
-    "scipy",
-    "netcdf4",
-    "fsspec",
+    "ruff",
    "s3fs",
-    "fastparquet",
-    "h5py"
+    "scipy",
]


@@ -70,12 +72,22 @@ exclude = ["docs", "tests", "tests.*", "docs.*"]
[tool.setuptools.package-data]
datatree = ["py.typed"]



-[mypy]
+[tool.mypy]
files = "virtualizarr/**/*.py"
show_error_codes = true

+[[tool.mypy.overrides]]
+module = "fsspec.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "numcodecs.*"
+ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = "kerchunk.*"
+ignore_missing_imports = true
+
[tool.ruff]
# Same as Black.
line-length = 88
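A note on the `[[tool.mypy.overrides]]` blocks added above: fsspec, numcodecs, and kerchunk ship without type stubs or a `py.typed` marker, so strict import checking would otherwise fail on them. A minimal sketch of what each per-module override suppresses (the error text is typical mypy output, not copied from this CI run):

    # Without ignore_missing_imports = true for "kerchunk.*", mypy reports
    # something like:
    #   error: Skipping analyzing "kerchunk": module is installed, but
    #   missing library stubs or py.typed marker
    import kerchunk.hdf

    # With the override, the import type-checks and the module's symbols
    # are simply treated as Any.
    converter = kerchunk.hdf.SingleHdf5ToZarr  # inferred as Any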
6 changes: 4 additions & 2 deletions virtualizarr/kerchunk.py
@@ -42,6 +42,7 @@ class FileType(AutoName):
tiff = auto()
fits = auto()
zarr = auto()
+    zarr_v3 = auto()


class NumpyEncoder(json.JSONEncoder):
@@ -223,7 +224,7 @@ def dataset_to_kerchunk_refs(ds: xr.Dataset) -> KerchunkStoreRefs:

all_arr_refs = {}
for var_name, var in ds.variables.items():
-        arr_refs = variable_to_kerchunk_arr_refs(var, var_name)
+        arr_refs = variable_to_kerchunk_arr_refs(var, str(var_name))

prepended_with_var_name = {
f"{var_name}/{key}": val for key, val in arr_refs.items()
@@ -233,7 +234,7 @@ def dataset_to_kerchunk_refs(ds: xr.Dataset) -> KerchunkStoreRefs:

zattrs = ds.attrs
if ds.coords:
-        coord_names = list(ds.coords)
+        coord_names = [str(x) for x in ds.coords]
# this weird concatenated string instead of a list of strings is inconsistent with how other features in the kerchunk references format are stored
# see https://github.com/zarr-developers/VirtualiZarr/issues/105#issuecomment-2187266739
zattrs["coordinates"] = " ".join(coord_names)
Expand Down Expand Up @@ -302,6 +303,7 @@ def variable_to_kerchunk_arr_refs(var: xr.Variable, var_name: str) -> KerchunkAr
shape=np_arr.shape,
dtype=np_arr.dtype,
order="C",
+        fill_value=None,
)

zarray_dict = zarray.to_kerchunk_json()
2 changes: 1 addition & 1 deletion virtualizarr/manifests/array.py
@@ -131,7 +131,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
return _isnan(self.shape)
return NotImplemented

-    def __array__(self) -> np.ndarray:
+    def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
raise NotImplementedError(
"ManifestArrays can't be converted into numpy arrays or pandas Index objects"
)
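Context for the signature change: numpy's coercion protocol may pass an optional dtype to `__array__`, so the wider signature matches what the type checker expects even though this implementation only raises. A minimal sketch of the protocol for a class that does convert (illustrative, not VirtualiZarr code):

    import numpy as np
    import numpy.typing as npt

    class WrapsData:
        def __init__(self, data: list[float]) -> None:
            self._data = data

        def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray:
            # np.asarray(obj) invokes obj.__array__(), optionally
            # forwarding a requested dtype.
            return np.asarray(self._data, dtype=dtype)

    print(np.asarray(WrapsData([1.0, 2.0])))  # array([1., 2.])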
4 changes: 2 additions & 2 deletions virtualizarr/manifests/manifest.py
@@ -57,7 +57,7 @@ def to_kerchunk(self) -> tuple[str, int, int]:
"""Write out in the format that kerchunk uses for chunk entries."""
return (self.path, self.offset, self.length)

-    def dict(self) -> ChunkDictEntry:
+    def dict(self) -> ChunkDictEntry:  # type: ignore[override]
return ChunkDictEntry(path=self.path, offset=self.offset, length=self.length)


@@ -238,7 +238,7 @@ def __iter__(self) -> Iterator[ChunkKey]:
def __len__(self) -> int:
return self._paths.size

-    def dict(self) -> ChunkDict:
+    def dict(self) -> ChunkDict:  # type: ignore[override]
"""
Convert the entire manifest to a nested dictionary.

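The `# type: ignore[override]` comments above are needed because these `dict()` methods shadow a base-class method (here, pydantic's `BaseModel.dict`) with a different signature, which mypy reports as an incompatible override. A generic sketch of the pattern (class names and values hypothetical):

    class Base:
        def dict(self, *, include: set | None = None) -> dict:
            return {}

    class Child(Base):
        # Dropping the base class's keyword argument makes this override
        # incompatible, so mypy emits an [override] error; the targeted
        # ignore silences only that error code on this line.
        def dict(self) -> dict:  # type: ignore[override]
            return {"path": "file.nc", "offset": 0, "length": 100}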
3 changes: 2 additions & 1 deletion virtualizarr/tests/test_zarr.py
@@ -6,6 +6,7 @@
import xarray.testing as xrt

from virtualizarr import ManifestArray, open_virtual_dataset
+from virtualizarr.kerchunk import FileType
from virtualizarr.manifests.manifest import ChunkManifest
from virtualizarr.zarr import dataset_to_zarr, metadata_from_zarr_json

@@ -40,7 +41,7 @@ def isconfigurable(value: dict) -> bool:
def test_zarr_v3_roundtrip(tmpdir, vds_with_manifest_arrays: xr.Dataset):
vds_with_manifest_arrays.virtualize.to_zarr(tmpdir / "store.zarr")
roundtrip = open_virtual_dataset(
tmpdir / "store.zarr", filetype="zarr_v3", indexes={}
tmpdir / "store.zarr", filetype=FileType.zarr_v3, indexes={}
)

xrt.assert_identical(roundtrip, vds_with_manifest_arrays)
37 changes: 29 additions & 8 deletions virtualizarr/xarray.py
@@ -1,18 +1,23 @@
+import os
import warnings
from collections.abc import Iterable, Mapping, MutableMapping
+from io import BufferedIOBase
from pathlib import Path
from typing import (
+    Any,
Callable,
+    Hashable,
Literal,
Optional,
+    cast,
overload,
)

import ujson # type: ignore
import xarray as xr
from upath import UPath
from xarray import register_dataset_accessor
-from xarray.backends import BackendArray
+from xarray.backends import AbstractDataStore, BackendArray
from xarray.coding.times import CFDatetimeCoder
from xarray.core.indexes import Index, PandasIndex
from xarray.core.variable import IndexVariable
@@ -27,6 +32,8 @@
metadata_from_zarr_json,
)

+XArrayOpenT = str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore


class ManifestBackendArray(ManifestArray, BackendArray):
"""Using this prevents xarray from wrapping the KerchunkArray in ExplicitIndexingAdapter etc."""
@@ -85,6 +92,9 @@ def open_virtual_dataset(
vds
An xarray Dataset containing instances of virtual_array_cls for each variable, or normal lazily indexed arrays for each variable in loadable_variables.
"""
+    loadable_vars: dict[str, xr.Variable]
+    virtual_vars: dict[str, xr.Variable]
+    vars: dict[str, xr.Variable]

if drop_variables is None:
drop_variables = []
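The three variable annotations added at the top of the function body are a common mypy pattern: declare the type once, then assign on different branches. A minimal sketch of why this helps (the function and values are hypothetical):

    import xarray as xr

    def collect(flag: bool) -> dict[str, xr.Variable]:
        # Declaring the type up front lets mypy check each branch's
        # assignment against one declared type, instead of inferring a
        # type (or failing to) from whichever assignment comes first.
        out: dict[str, xr.Variable]
        if flag:
            out = {}
        else:
            out = {"x": xr.Variable(("x",), [1, 2])}
        return out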
@@ -119,7 +129,11 @@
if virtual_array_class is not ManifestArray:
raise NotImplementedError()

if filetype == "zarr_v3":
# if filetype is user defined, convert to FileType
if filetype is not None:
filetype = FileType(filetype)

if filetype == FileType.zarr_v3:
# TODO is there a neat way of auto-detecting this?
return open_virtual_dataset_from_v3_store(
storepath=filepath, drop_variables=drop_variables, indexes=indexes
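With the coercion above, callers can pass `filetype` either as a string or as a `FileType` member; `FileType("zarr_v3")` resolves to the same member as `FileType.zarr_v3` because the enum's values are auto-generated from its names. A usage sketch (the store path is hypothetical):

    from virtualizarr import open_virtual_dataset
    from virtualizarr.kerchunk import FileType

    # Equivalent calls after this change:
    vds1 = open_virtual_dataset("store.zarr", filetype="zarr_v3", indexes={})
    vds2 = open_virtual_dataset("store.zarr", filetype=FileType.zarr_v3, indexes={})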
@@ -158,8 +172,13 @@ def open_virtual_dataset(
filepath=filepath, reader_options=reader_options
)

+    # fpath can be `Any` thanks to fsspec.filesystem(...).open() returning Any.
+    # We'll (hopefully safely) cast it to what xarray is expecting, but this might let errors through.
+
    ds = xr.open_dataset(
-        fpath, drop_variables=drop_variables, decode_times=False
+        cast(XArrayOpenT, fpath),
+        drop_variables=drop_variables,
+        decode_times=False,
)

if indexes is None:
@@ -177,7 +196,7 @@
indexes = dict(**indexes) # for type hinting: to allow mutation

loadable_vars = {
-        name: var
+        str(name): var
for name, var in ds.variables.items()
if name in loadable_variables
}
@@ -265,7 +284,7 @@ def virtual_vars_from_kerchunk_refs(
refs: KerchunkStoreRefs,
drop_variables: list[str] | None = None,
virtual_array_class=ManifestArray,
-) -> Mapping[str, xr.Variable]:
+) -> dict[str, xr.Variable]:
    """
    Translate a store-level kerchunk reference dict into a set of xarray Variables containing virtualized arrays.

@@ -351,7 +370,7 @@ def separate_coords(
vars: Mapping[str, xr.Variable],
indexes: MutableMapping[str, Index],
coord_names: Iterable[str] | None = None,
-) -> tuple[Mapping[str, xr.Variable], xr.Coordinates]:
+) -> tuple[dict[str, xr.Variable], xr.Coordinates]:
"""
Try to generate a set of coordinates that won't cause xarray to automatically build a pandas.Index for the 1D coordinates.

@@ -365,7 +384,9 @@

# split data and coordinate variables (promote dimension coordinates)
data_vars = {}
-    coord_vars = {}
+    coord_vars: dict[
+        str, tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]] | xr.Variable
+    ] = {}
for name, var in vars.items():
if name in coord_names or var.dims == (name,):
# use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263
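For context, the union type annotated on `coord_vars` above reflects the linked workaround: a coordinate may be stored either as an `xr.Variable` or as a raw `(dims, data, attrs, encoding)` tuple, which sidesteps the `IndexVariable` creation described in the linked xarray discussion. A sketch of the tuple form (names and values illustrative):

    import numpy as np

    # Instead of coord_vars[name] = var, pass the pieces as a 4-tuple:
    coord_vars = {"x": (("x",), np.arange(3), {"units": "m"}, {})}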
@@ -376,7 +397,7 @@
if isinstance(var, IndexVariable):
# unless variable actually already is a loaded IndexVariable,
# in which case we need to keep it and add the corresponding indexes explicitly
-                coord_vars[name] = var
+                coord_vars[str(name)] = var
# TODO this seems suspect - will it handle datetimes?
indexes[name] = PandasIndex(var, dim1d)
else:
20 changes: 14 additions & 6 deletions virtualizarr/zarr.py
@@ -6,6 +6,7 @@
Literal,
NewType,
Optional,
+    cast,
)

import numcodecs
@@ -31,6 +32,7 @@
"ZAttrs", dict[str, Any]
) # just the .zattrs (for one array or for the whole store/group)
FillValueT = bool | str | float | int | list | None
+ZARR_FORMAT = Literal[2, 3]

ZARR_DEFAULT_FILL_VALUE: dict[str, FillValueT] = {
# numpy dtypes's hierarchy lets us avoid checking for all the widths
@@ -72,7 +74,7 @@ class ZArray(BaseModel):
filters: list[dict] | None = None
order: Literal["C", "F"]
shape: tuple[int, ...]
-    zarr_format: Literal[2, 3] = 2
+    zarr_format: ZARR_FORMAT = 2

@field_validator("dtype")
@classmethod
@@ -110,6 +112,10 @@ def from_kerchunk_refs(cls, decoded_arr_refs_zarray) -> "ZArray":
fill_value = np.nan

compressor = decoded_arr_refs_zarray["compressor"]
+        zarr_format = int(decoded_arr_refs_zarray["zarr_format"])
+        if zarr_format not in (2, 3):
+            raise ValueError(f"Zarr format must be 2 or 3, but got {zarr_format}")

return ZArray(
chunks=tuple(decoded_arr_refs_zarray["chunks"]),
compressor=compressor,
@@ -118,10 +124,10 @@ def from_kerchunk_refs(cls, decoded_arr_refs_zarray) -> "ZArray":
filters=decoded_arr_refs_zarray["filters"],
order=decoded_arr_refs_zarray["order"],
shape=tuple(decoded_arr_refs_zarray["shape"]),
-            zarr_format=int(decoded_arr_refs_zarray["zarr_format"]),
+            zarr_format=cast(ZARR_FORMAT, zarr_format),
)
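The validate-then-`cast` pattern above is needed because `int(...)` types the value as plain `int`, and mypy does not narrow `int` to `Literal[2, 3]` from a membership test. A standalone sketch:

    from typing import Literal, cast

    ZARR_FORMAT = Literal[2, 3]

    def parse_format(raw: str) -> ZARR_FORMAT:
        value = int(raw)  # inferred as int
        if value not in (2, 3):
            raise ValueError(f"Zarr format must be 2 or 3, but got {value}")
        # The runtime check guarantees the value, but the type checker
        # cannot see that, so we cast for mypy's benefit.
        return cast(ZARR_FORMAT, value)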

-    def dict(self) -> dict[str, Any]:
+    def dict(self) -> dict[str, Any]:  # type: ignore
zarray_dict = dict(self)
zarray_dict["dtype"] = encode_dtype(zarray_dict["dtype"])
return zarray_dict
@@ -135,7 +141,7 @@ def to_kerchunk_json(self) -> str:
def replace(
self,
chunks: Optional[tuple[int, ...]] = None,
-        compressor: Optional[dict] = None,
+        compressor: Optional[dict] = None,  # type: ignore[valid-type]
dtype: Optional[np.dtype] = None,
fill_value: Optional[float] = None, # float or int?
filters: Optional[list[dict]] = None, # type: ignore[valid-type]
@@ -251,7 +257,7 @@ def dataset_to_zarr(ds: xr.Dataset, storepath: str) -> None:
group_metadata_file.write(json_dumps(group_metadata))

for name, var in ds.variables.items():
-        array_dir = _storepath / name
+        array_dir = _storepath / str(name)
marr = var.data

# TODO move this check outside the writing loop so we don't write an incomplete store on failure?
@@ -287,7 +293,9 @@ def to_zarr_json(var: xr.Variable, array_dir: Path) -> None:

marr.manifest.to_zarr_json(array_dir / "manifest.json")

-    metadata = zarr_v3_array_metadata(marr.zarray, list(var.dims), var.attrs)
+    metadata = zarr_v3_array_metadata(
+        marr.zarray, [str(x) for x in var.dims], var.attrs
+    )
with open(array_dir / "zarr.json", "wb") as metadata_file:
metadata_file.write(json_dumps(metadata))
