From ce83f63b19bd8085210d4abe6b0a8ec1cf073c46 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Sat, 9 Nov 2024 19:51:53 +0000 Subject: [PATCH] Some zarr3 typing fixes --- .pre-commit-config.yaml | 2 +- numcodecs/tests/test_zarr3.py | 33 +++++++++++++++++++---------- numcodecs/zarr3.py | 39 +++++++++++++++++------------------ 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 96d56933..37407f24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,4 +31,4 @@ repos: hooks: - id: mypy args: [--config-file, pyproject.toml] - additional_dependencies: [numpy, pytest, zfpy] + additional_dependencies: [numpy, pytest, zfpy, 'zarr==3.0.0b1'] diff --git a/numcodecs/tests/test_zarr3.py b/numcodecs/tests/test_zarr3.py index ec1a398e..afb77136 100644 --- a/numcodecs/tests/test_zarr3.py +++ b/numcodecs/tests/test_zarr3.py @@ -1,11 +1,19 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pytest -zarr = pytest.importorskip("zarr") +if not TYPE_CHECKING: + zarr = pytest.importorskip("zarr") +else: + import zarr + +import zarr.storage +from zarr.core.common import JSON -import numcodecs.zarr3 # noqa: E402 +import numcodecs.zarr3 pytestmark = [ pytest.mark.skipif(zarr.__version__ < "3.0.0", reason="zarr 3.0.0 or later is required"), @@ -17,7 +25,6 @@ get_codec_class = zarr.registry.get_codec_class Array = zarr.Array -JSON = zarr.core.common.JSON BytesCodec = zarr.codecs.BytesCodec Store = zarr.abc.store.Store MemoryStore = zarr.storage.MemoryStore @@ -28,7 +35,7 @@ @pytest.fixture -def store() -> Store: +def store() -> StorePath: return StorePath(MemoryStore(mode="w")) @@ -43,6 +50,8 @@ def test_entry_points(codec_class: type[numcodecs.zarr3._NumcodecsCodec]): @pytest.mark.parametrize("codec_class", ALL_CODECS) def test_docstring(codec_class: type[numcodecs.zarr3._NumcodecsCodec]): + if codec_class.__doc__ is None: + pytest.skip() assert "See :class:`numcodecs." in codec_class.__doc__ @@ -59,7 +68,7 @@ def test_docstring(codec_class: type[numcodecs.zarr3._NumcodecsCodec]): numcodecs.zarr3.Shuffle, ], ) -def test_generic_codec_class(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]): +def test_generic_codec_class(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]): data = np.arange(0, 256, dtype="uint16").reshape((16, 16)) with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR): @@ -92,7 +101,9 @@ def test_generic_codec_class(store: Store, codec_class: type[numcodecs.zarr3._Nu ], ) def test_generic_filter( - store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec], codec_config: dict[str, JSON] + store: StorePath, + codec_class: type[numcodecs.zarr3._NumcodecsCodec], + codec_config: dict[str, JSON], ): data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16)) @@ -114,7 +125,7 @@ def test_generic_filter( np.testing.assert_array_equal(data, a[:, :]) -def test_generic_filter_bitround(store: Store): +def test_generic_filter_bitround(store: StorePath): data = np.linspace(0, 1, 256, dtype="float32").reshape((16, 16)) with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR): @@ -132,7 +143,7 @@ def test_generic_filter_bitround(store: Store): assert np.allclose(data, a[:, :], atol=0.1) -def test_generic_filter_quantize(store: Store): +def test_generic_filter_quantize(store: StorePath): data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16)) with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR): @@ -150,7 +161,7 @@ def test_generic_filter_quantize(store: Store): assert np.allclose(data, a[:, :], atol=0.001) -def test_generic_filter_packbits(store: Store): +def test_generic_filter_packbits(store: StorePath): data = np.zeros((16, 16), dtype="bool") data[0:4, :] = True @@ -189,7 +200,7 @@ def test_generic_filter_packbits(store: Store): numcodecs.zarr3.JenkinsLookup3, ], ) -def test_generic_checksum(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]): +def test_generic_checksum(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]): data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16)) with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR): @@ -208,7 +219,7 @@ def test_generic_checksum(store: Store, codec_class: type[numcodecs.zarr3._Numco @pytest.mark.parametrize("codec_class", [numcodecs.zarr3.PCodec, numcodecs.zarr3.ZFPY]) -def test_generic_bytes_codec(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]): +def test_generic_bytes_codec(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]): try: codec_class()._codec # noqa: B018 except ValueError as e: diff --git a/numcodecs/zarr3.py b/numcodecs/zarr3.py index 811ab501..119afc00 100644 --- a/numcodecs/zarr3.py +++ b/numcodecs/zarr3.py @@ -26,7 +26,6 @@ import asyncio import math -from collections.abc import Callable from dataclasses import dataclass, replace from functools import cached_property, partial from typing import Any, Self, TypeVar @@ -76,7 +75,7 @@ class _NumcodecsCodec: codec_name: str codec_config: dict[str, JSON] - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: if not self.codec_name: raise ValueError( "The codec name needs to be supplied through the `codec_name` attribute." @@ -106,7 +105,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: codec_config = _parse_codec_configuration(data) return cls(**codec_config) - def to_dict(self) -> JSON: + def to_dict(self) -> dict[str, JSON]: codec_config = self.codec_config.copy() return { "name": self.codec_name, @@ -118,7 +117,7 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> class _NumcodecsBytesBytesCodec(_NumcodecsCodec, BytesBytesCodec): - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer: @@ -140,7 +139,7 @@ async def _encode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Bu class _NumcodecsArrayArrayCodec(_NumcodecsCodec, ArrayArrayCodec): - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) async def _decode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer: @@ -155,7 +154,7 @@ async def _encode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> class _NumcodecsArrayBytesCodec(_NumcodecsCodec, ArrayBytesCodec): - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) async def _decode_single(self, chunk_buffer: Buffer, chunk_spec: ArraySpec) -> NDBuffer: @@ -179,7 +178,7 @@ def _add_docstring(cls: type[T], ref_class_name: str) -> type[T]: return cls -def _add_docstring_wrapper(ref_class_name: str) -> Callable[[type[T]], type[T]]: +def _add_docstring_wrapper(ref_class_name: str) -> partial: return partial(_add_docstring, ref_class_name=ref_class_name) @@ -190,7 +189,7 @@ def _make_bytes_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBy class _Codec(_NumcodecsBytesBytesCodec): codec_name = _codec_name - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) _Codec.__name__ = cls_name @@ -204,7 +203,7 @@ def _make_array_array_codec(codec_name: str, cls_name: str) -> type[_NumcodecsAr class _Codec(_NumcodecsArrayArrayCodec): codec_name = _codec_name - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) _Codec.__name__ = cls_name @@ -218,7 +217,7 @@ def _make_array_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsAr class _Codec(_NumcodecsArrayBytesCodec): codec_name = _codec_name - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) _Codec.__name__ = cls_name @@ -232,7 +231,7 @@ def _make_checksum_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBytes class _ChecksumCodec(_NumcodecsBytesBytesCodec): codec_name = _codec_name - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int: @@ -256,10 +255,10 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> class Shuffle(_NumcodecsBytesBytesCodec): codec_name = f"{CODEC_PREFIX}shuffle" - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) - def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle: if array_spec.dtype.itemsize != self.codec_config.get("elementsize"): return Shuffle(**{**self.codec_config, "elementsize": array_spec.dtype.itemsize}) return self # pragma: no cover @@ -276,7 +275,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: class FixedScaleOffset(_NumcodecsArrayArrayCodec): codec_name = f"{CODEC_PREFIX}fixedscaleoffset" - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: @@ -284,7 +283,7 @@ def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: return replace(chunk_spec, dtype=np.dtype(astype)) return chunk_spec - def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset: if str(array_spec.dtype) != self.codec_config.get("dtype"): return FixedScaleOffset(**{**self.codec_config, "dtype": str(array_spec.dtype)}) return self @@ -294,10 +293,10 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: class Quantize(_NumcodecsArrayArrayCodec): codec_name = f"{CODEC_PREFIX}quantize" - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) - def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize: if str(array_spec.dtype) != self.codec_config.get("dtype"): return Quantize(**{**self.codec_config, "dtype": str(array_spec.dtype)}) return self @@ -307,7 +306,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: class PackBits(_NumcodecsArrayArrayCodec): codec_name = f"{CODEC_PREFIX}packbits" - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: @@ -326,13 +325,13 @@ def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None: class AsType(_NumcodecsArrayArrayCodec): codec_name = f"{CODEC_PREFIX}astype" - def __init__(self, **codec_config: dict[str, JSON]) -> None: + def __init__(self, **codec_config: JSON) -> None: super().__init__(**codec_config) def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec: return replace(chunk_spec, dtype=np.dtype(self.codec_config["encode_dtype"])) - def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType: decode_dtype = self.codec_config.get("decode_dtype") if str(array_spec.dtype) != decode_dtype: return AsType(**{**self.codec_config, "decode_dtype": str(array_spec.dtype)})