Skip to content

Commit

Permalink
Some zarr3 typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby committed Nov 9, 2024
1 parent 05254a6 commit ce83f63
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ repos:
hooks:
- id: mypy
args: [--config-file, pyproject.toml]
additional_dependencies: [numpy, pytest, zfpy]
additional_dependencies: [numpy, pytest, zfpy, 'zarr==3.0.0b1']
33 changes: 22 additions & 11 deletions numcodecs/tests/test_zarr3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

zarr = pytest.importorskip("zarr")
if not TYPE_CHECKING:
zarr = pytest.importorskip("zarr")
else:
import zarr

Check warning on line 11 in numcodecs/tests/test_zarr3.py

View check run for this annotation

Codecov / codecov/patch

numcodecs/tests/test_zarr3.py#L11

Added line #L11 was not covered by tests

import zarr.storage
from zarr.core.common import JSON

import numcodecs.zarr3 # noqa: E402
import numcodecs.zarr3

pytestmark = [
pytest.mark.skipif(zarr.__version__ < "3.0.0", reason="zarr 3.0.0 or later is required"),
Expand All @@ -17,7 +25,6 @@

get_codec_class = zarr.registry.get_codec_class
Array = zarr.Array
JSON = zarr.core.common.JSON
BytesCodec = zarr.codecs.BytesCodec
Store = zarr.abc.store.Store
MemoryStore = zarr.storage.MemoryStore
Expand All @@ -28,7 +35,7 @@


@pytest.fixture
def store() -> Store:
def store() -> StorePath:
return StorePath(MemoryStore(mode="w"))


Expand All @@ -43,6 +50,8 @@ def test_entry_points(codec_class: type[numcodecs.zarr3._NumcodecsCodec]):

@pytest.mark.parametrize("codec_class", ALL_CODECS)
def test_docstring(codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
if codec_class.__doc__ is None:
pytest.skip()

Check warning on line 54 in numcodecs/tests/test_zarr3.py

View check run for this annotation

Codecov / codecov/patch

numcodecs/tests/test_zarr3.py#L54

Added line #L54 was not covered by tests
assert "See :class:`numcodecs." in codec_class.__doc__


Expand All @@ -59,7 +68,7 @@ def test_docstring(codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
numcodecs.zarr3.Shuffle,
],
)
def test_generic_codec_class(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
def test_generic_codec_class(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
data = np.arange(0, 256, dtype="uint16").reshape((16, 16))

with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
Expand Down Expand Up @@ -92,7 +101,9 @@ def test_generic_codec_class(store: Store, codec_class: type[numcodecs.zarr3._Nu
],
)
def test_generic_filter(
store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec], codec_config: dict[str, JSON]
store: StorePath,
codec_class: type[numcodecs.zarr3._NumcodecsCodec],
codec_config: dict[str, JSON],
):
data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16))

Expand All @@ -114,7 +125,7 @@ def test_generic_filter(
np.testing.assert_array_equal(data, a[:, :])


def test_generic_filter_bitround(store: Store):
def test_generic_filter_bitround(store: StorePath):
data = np.linspace(0, 1, 256, dtype="float32").reshape((16, 16))

with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
Expand All @@ -132,7 +143,7 @@ def test_generic_filter_bitround(store: Store):
assert np.allclose(data, a[:, :], atol=0.1)


def test_generic_filter_quantize(store: Store):
def test_generic_filter_quantize(store: StorePath):
data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16))

with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
Expand All @@ -150,7 +161,7 @@ def test_generic_filter_quantize(store: Store):
assert np.allclose(data, a[:, :], atol=0.001)


def test_generic_filter_packbits(store: Store):
def test_generic_filter_packbits(store: StorePath):
data = np.zeros((16, 16), dtype="bool")
data[0:4, :] = True

Expand Down Expand Up @@ -189,7 +200,7 @@ def test_generic_filter_packbits(store: Store):
numcodecs.zarr3.JenkinsLookup3,
],
)
def test_generic_checksum(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
def test_generic_checksum(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
data = np.linspace(0, 10, 256, dtype="float32").reshape((16, 16))

with pytest.warns(UserWarning, match=EXPECTED_WARNING_STR):
Expand All @@ -208,7 +219,7 @@ def test_generic_checksum(store: Store, codec_class: type[numcodecs.zarr3._Numco


@pytest.mark.parametrize("codec_class", [numcodecs.zarr3.PCodec, numcodecs.zarr3.ZFPY])
def test_generic_bytes_codec(store: Store, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
def test_generic_bytes_codec(store: StorePath, codec_class: type[numcodecs.zarr3._NumcodecsCodec]):
try:
codec_class()._codec # noqa: B018
except ValueError as e:
Expand Down
39 changes: 19 additions & 20 deletions numcodecs/zarr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import asyncio
import math
from collections.abc import Callable
from dataclasses import dataclass, replace
from functools import cached_property, partial
from typing import Any, Self, TypeVar
Expand Down Expand Up @@ -76,7 +75,7 @@ class _NumcodecsCodec:
codec_name: str
codec_config: dict[str, JSON]

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
if not self.codec_name:
raise ValueError(
"The codec name needs to be supplied through the `codec_name` attribute."
Expand Down Expand Up @@ -106,7 +105,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
codec_config = _parse_codec_configuration(data)
return cls(**codec_config)

def to_dict(self) -> JSON:
def to_dict(self) -> dict[str, JSON]:
codec_config = self.codec_config.copy()
return {
"name": self.codec_name,
Expand All @@ -118,7 +117,7 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) ->


class _NumcodecsBytesBytesCodec(_NumcodecsCodec, BytesBytesCodec):
def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Buffer:
Expand All @@ -140,7 +139,7 @@ async def _encode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> Bu


class _NumcodecsArrayArrayCodec(_NumcodecsCodec, ArrayArrayCodec):
def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

async def _decode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer:
Expand All @@ -155,7 +154,7 @@ async def _encode_single(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) ->


class _NumcodecsArrayBytesCodec(_NumcodecsCodec, ArrayBytesCodec):
def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

async def _decode_single(self, chunk_buffer: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
Expand All @@ -179,7 +178,7 @@ def _add_docstring(cls: type[T], ref_class_name: str) -> type[T]:
return cls


def _add_docstring_wrapper(ref_class_name: str) -> Callable[[type[T]], type[T]]:
def _add_docstring_wrapper(ref_class_name: str) -> partial:
return partial(_add_docstring, ref_class_name=ref_class_name)


Expand All @@ -190,7 +189,7 @@ def _make_bytes_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBy
class _Codec(_NumcodecsBytesBytesCodec):
codec_name = _codec_name

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

_Codec.__name__ = cls_name
Expand All @@ -204,7 +203,7 @@ def _make_array_array_codec(codec_name: str, cls_name: str) -> type[_NumcodecsAr
class _Codec(_NumcodecsArrayArrayCodec):
codec_name = _codec_name

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

_Codec.__name__ = cls_name
Expand All @@ -218,7 +217,7 @@ def _make_array_bytes_codec(codec_name: str, cls_name: str) -> type[_NumcodecsAr
class _Codec(_NumcodecsArrayBytesCodec):
codec_name = _codec_name

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

_Codec.__name__ = cls_name
Expand All @@ -232,7 +231,7 @@ def _make_checksum_codec(codec_name: str, cls_name: str) -> type[_NumcodecsBytes
class _ChecksumCodec(_NumcodecsBytesBytesCodec):
codec_name = _codec_name

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
Expand All @@ -256,10 +255,10 @@ def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) ->
class Shuffle(_NumcodecsBytesBytesCodec):
codec_name = f"{CODEC_PREFIX}shuffle"

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Shuffle:
if array_spec.dtype.itemsize != self.codec_config.get("elementsize"):
return Shuffle(**{**self.codec_config, "elementsize": array_spec.dtype.itemsize})
return self # pragma: no cover
Expand All @@ -276,15 +275,15 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
class FixedScaleOffset(_NumcodecsArrayArrayCodec):
codec_name = f"{CODEC_PREFIX}fixedscaleoffset"

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
if astype := self.codec_config.get("astype"):
return replace(chunk_spec, dtype=np.dtype(astype))
return chunk_spec

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
def evolve_from_array_spec(self, array_spec: ArraySpec) -> FixedScaleOffset:
if str(array_spec.dtype) != self.codec_config.get("dtype"):
return FixedScaleOffset(**{**self.codec_config, "dtype": str(array_spec.dtype)})
return self
Expand All @@ -294,10 +293,10 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
class Quantize(_NumcodecsArrayArrayCodec):
codec_name = f"{CODEC_PREFIX}quantize"

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Quantize:
if str(array_spec.dtype) != self.codec_config.get("dtype"):
return Quantize(**{**self.codec_config, "dtype": str(array_spec.dtype)})
return self
Expand All @@ -307,7 +306,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
class PackBits(_NumcodecsArrayArrayCodec):
codec_name = f"{CODEC_PREFIX}packbits"

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
Expand All @@ -326,13 +325,13 @@ def validate(self, *, dtype: np.dtype[Any], **_kwargs) -> None:
class AsType(_NumcodecsArrayArrayCodec):
codec_name = f"{CODEC_PREFIX}astype"

def __init__(self, **codec_config: dict[str, JSON]) -> None:
def __init__(self, **codec_config: JSON) -> None:
super().__init__(**codec_config)

def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
return replace(chunk_spec, dtype=np.dtype(self.codec_config["encode_dtype"]))

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
def evolve_from_array_spec(self, array_spec: ArraySpec) -> AsType:
decode_dtype = self.codec_config.get("decode_dtype")
if str(array_spec.dtype) != decode_dtype:
return AsType(**{**self.codec_config, "decode_dtype": str(array_spec.dtype)})
Expand Down

0 comments on commit ce83f63

Please sign in to comment.