diff --git a/src/zarr/core/dtype/__init__.py b/src/zarr/core/dtype/__init__.py new file mode 100644 index 000000000..7706534f7 --- /dev/null +++ b/src/zarr/core/dtype/__init__.py @@ -0,0 +1,7 @@ +from zarr.core.dtype.core import ( + ZarrDType +) + +__all__ = [ + "ZarrDType" +] diff --git a/src/zarr/core/dtype/core.py b/src/zarr/core/dtype/core.py new file mode 100644 index 000000000..a70f9f17f --- /dev/null +++ b/src/zarr/core/dtype/core.py @@ -0,0 +1,204 @@ +""" +# Overview + +This module provides a proof-of-concept standalone interface for managing dtypes in the zarr-python codebase. + +The `ZarrDType` class introduced in this module effectively acts as a replacement for `np.dtype` throughout the +zarr-python codebase. It attempts to encapsulate all relevant runtime information necessary for working with +dtypes in the context of the Zarr V3 specification (e.g. is this a core dtype or not, how many bytes and what +endianness is the dtype etc). By providing this abstraction, the module aims to: + +- Simplify dtype management within zarr-python +- Support runtime flexibility and custom extensions +- Remove unnecessary dependencies on the numpy API + +## Extensibility + +The module attempts to support user-driven extensions, allowing developers to introduce custom dtypes +without requiring immediate changes to zarr-python. Extensions can leverage the current entrypoint mechanism, +enabling integration of experimental features. Over time, widely adopted extensions may be formalized through +inclusion in zarr-python or standardized via a Zarr Enhancement Proposal (ZEP), but this is not essential. + +## Examples + +### Core `dtype` Registration + +The following example demonstrates how to register a built-in `dtype` in the core codebase: + +```python +from zarr.core.dtype import ZarrDType +from zarr.registry import register_v3dtype + +class Float16(ZarrDType): + zarr_spec_format = "3" + experimental = False + endianness = "little" + byte_count = 2 + to_numpy = np.dtype('float16') + +register_v3dtype(Float16) +``` + +### Entrypoint Extension + +The following example demonstrates how users can register a new `bfloat16` dtype for Zarr. +This approach adheres to the existing Zarr entrypoint pattern as much as possible, ensuring +consistency with other extensions. The code below would typically be part of a Python package +that specifies the entrypoints for the extension: + +```python +import ml_dtypes +from zarr.core.dtype import ZarrDType # User inherits from ZarrDType when creating their dtype + +class Bfloat16(ZarrDType): + zarr_spec_format = "3" + experimental = True + endianness = "little" + byte_count = 2 + to_numpy = np.dtype('bfloat16') # Enabled by importing ml_dtypes + configuration_v3 = { + "version": "example_value", + "author": "example_value", + "ml_dtypes_version": "example_value" + } +``` + +### dtype lookup + +The following examples demonstrate how to perform a lookup for the relevant ZarrDType, given +a string that matches the dtype Zarr specification ID, or a numpy dtype object: + +``` +from zarr.registry import get_v3dtype_class, get_v3dtype_class_from_numpy + +get_v3dtype_class('complex64') # returns little-endian Complex64 ZarrDType +get_v3dtype_class('not_registered_dtype') # ValueError + +get_v3dtype_class_from_numpy('>i2') # returns big-endian Int16 ZarrDType +get_v3dtype_class_from_numpy(np.dtype('float32')) # returns little-endian Float32 ZarrDType +get_v3dtype_class_from_numpy('i10') # ValueError +``` + +### String dtypes + +The following indicates one possibility for supporting variable-length strings. It is via the +entrypoint mechanism as in a previous example. The Apache Arrow specification does not currently +include a dtype for fixed-length strings (only for fixed-length bytes) and so I am using string +here to implicitly refer to a variable-length string data (there may be some subtleties with codecs +that means this needs to be refined further): + +```python +import numpy as np +from zarr.core.dtype import ZarrDType # User inherits from ZarrDType when creating their dtype + +try: + to_numpy = np.dtypes.StringDType() +except AttributeError: + to_numpy = np.dtypes.ObjectDType() + +class String(ZarrDType): + zarr_spec_format = "3" + experimental = True + endianness = 'little' + byte_count = None # None is defined to mean variable + to_numpy = to_numpy +``` + +### int4 dtype + +There is currently considerable interest in the AI community in 'quantising' models - storing +models at reduced precision, while minimising loss of information content. There are a number +of sub-byte dtypes that the community are using e.g. int4. Unfortunately numpy does not +currently have support for handling such sub-byte dtypes in an easy way. However, they can +still be held in a numpy array and then passed (in a zero-copy way) to something like pytorch +which can handle appropriately: + +```python +import numpy as np +from zarr.core.dtype import ZarrDType # User inherits from ZarrDType when creating their dtype + +class Int4(ZarrDType): + zarr_spec_format = "3" + experimental = True + endianness = 'little' + byte_count = 1 # this is ugly, but I could change this from byte_count to bit_count if there was consensus + to_numpy = np.dtype('B') # could also be np.dtype('V1'), but this would prevent bit-twiddling + configuration_v3 = { + "version": "example_value", + "author": "example_value", + } +``` +""" + +from __future__ import annotations + +from typing import Any, Literal + +import numpy as np + + +# perhaps over-complicating, but I don't want to allow the attributes to be patched +class FrozenClassVariables(type): + def __setattr__(cls, attr, value): + if hasattr(cls, attr): + raise ValueError( + f"Attribute {attr} on ZarrDType class can not be changed once set." + ) + + +class ZarrDType(metaclass=FrozenClassVariables): + + zarr_spec_format: Literal["2", "3"] # the version of the zarr spec used + experimental: bool # is this in the core spec or not + endianness: Literal[ + "big", "little", None + ] # None indicates not defined i.e. single byte or byte strings + byte_count: int | None # None indicates variable count + to_numpy: np.dtype[ + Any + ] # may involve installing a a numpy extension e.g. ml_dtypes; + + configuration_v3: ( + dict | None + ) # TODO: understand better how this is recommended by the spec + + _zarr_spec_identifier: str # implementation detail used to map to core spec + + def __init_subclass__( # enforces all required fields are set and basic sanity checks + cls, + **kwargs, + ) -> None: + + required_attrs = [ + "zarr_spec_format", + "experimental", + "endianness", + "byte_count", + "to_numpy", + ] + for attr in required_attrs: + if not hasattr(cls, attr): + raise ValueError(f"{attr} is a required attribute for a Zarr dtype.") + + if not hasattr(cls, "configuration_v3"): + cls.configuration_v3 = None + + cls._zarr_spec_identifier = ( + "big_" + cls.__qualname__.lower() + if cls.endianness == "big" + else cls.__qualname__.lower() + ) # how this dtype is identified in core spec; convention is prefix with big_ for big-endian + + cls._validate() # sanity check on basic requirements + + super().__init_subclass__(**kwargs) + + # TODO: add further checks + @classmethod + def _validate(cls): + + if cls.byte_count is not None and cls.byte_count <= 0: + raise ValueError("byte_count must be a positive integer.") + + if cls.byte_count == 1 and cls.endianness is not None: + raise ValueError("Endianness must be None for single-byte types.") diff --git a/src/zarr/registry.py b/src/zarr/registry.py index 704db3f70..392be0f46 100644 --- a/src/zarr/registry.py +++ b/src/zarr/registry.py @@ -5,11 +5,15 @@ from importlib.metadata import entry_points as get_entry_points from typing import TYPE_CHECKING, Any, Generic, TypeVar +import numpy as np + from zarr.core.config import BadConfigError, config if TYPE_CHECKING: from importlib.metadata import EntryPoint + import numpy.typing as npt + from zarr.abc.codec import ( ArrayArrayCodec, ArrayBytesCodec, @@ -19,6 +23,7 @@ ) from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import JSON + from zarr.core.dtype import ZarrDType __all__ = [ "Registry", @@ -26,10 +31,14 @@ "get_codec_class", "get_ndbuffer_class", "get_pipeline_class", + "get_v3dtype_class", + "get_v2dtype_class", "register_buffer", "register_codec", "register_ndbuffer", "register_pipeline", + "register_v3dtype", + "register_v2dtype", ] T = TypeVar("T") @@ -42,28 +51,52 @@ def __init__(self) -> None: def lazy_load(self) -> None: for e in self.lazy_load_list: - self.register(e.load()) + cls = e.load() + + if (hasattr(cls, "_zarr_spec_identifier") and hasattr(cls, "to_numpy")): + for val in self.values(): + # allow the V and B dtypes to be overloaded - something of an escape hatch. + if val.to_numpy == cls.to_numpy and cls.to_numpy.kind not in ["V", "B"]: + raise ValueError( + f"numpy dtype {cls.to_numpy} already exists for Zarr dtype {val}." + ) + + self.register(cls, cls._zarr_spec_identifier) + else: + self.register(cls) + self.lazy_load_list.clear() - def register(self, cls: type[T]) -> None: - self[fully_qualified_name(cls)] = cls + def register(self, cls: type[T], class_registry_key: str | None = None) -> None: + if class_registry_key is None: + self[fully_qualified_name(cls)] = cls + else: + if class_registry_key in self: + raise ValueError( + f"{class_registry_key} already exists in the registry. Please try a different name." + ) # not great, but don't want to have the possibility of clobbering existing core dtypes + self[class_registry_key] = cls __codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry) __pipeline_registry: Registry[CodecPipeline] = Registry() __buffer_registry: Registry[Buffer] = Registry() __ndbuffer_registry: Registry[NDBuffer] = Registry() +__v3_dtype_registry: Registry[ZarrDType] = Registry() +__v2_dtype_registry: Registry[ZarrDType] = Registry() """ The registry module is responsible for managing implementations of codecs, pipelines, buffers and ndbuffers and collecting them from entrypoints. The implementation used is determined by the config. + +The registry module is also responsible for managing dtypes. """ def _collect_entrypoints() -> list[Registry[Any]]: """ - Collects codecs, pipelines, buffers and ndbuffers from entrypoints. + Collects codecs, pipelines, dtypes, buffers and ndbuffers from entrypoints. Entry points can either be single items or groups of items. Allowed syntax for entry_points.txt is e.g. @@ -83,10 +116,26 @@ def _collect_entrypoints() -> list[Registry[Any]]: entry_points = get_entry_points() __buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.buffer")) - __buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="buffer")) - __ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.ndbuffer")) - __ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="ndbuffer")) - __pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline")) + __buffer_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="buffer") + ) + __ndbuffer_registry.lazy_load_list.extend( + entry_points.select(group="zarr.ndbuffer") + ) + __ndbuffer_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="ndbuffer") + ) + __v3_dtype_registry.lazy_load_list.extend(entry_points.select(group="zarr.v3dtype")) + __v3_dtype_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="v3dtype") + ) + __v2_dtype_registry.lazy_load_list.extend(entry_points.select(group="zarr.v2dtype")) + __v2_dtype_registry.lazy_load_list.extend( + entry_points.select(group="zarr", name="v2dtype") + ) + __pipeline_registry.lazy_load_list.extend( + entry_points.select(group="zarr.codec_pipeline") + ) __pipeline_registry.lazy_load_list.extend( entry_points.select(group="zarr", name="codec_pipeline") ) @@ -95,7 +144,9 @@ def _collect_entrypoints() -> list[Registry[Any]]: for group in entry_points.groups: if group.startswith("zarr.codecs."): codec_name = group.split(".")[2] - __codec_registries[codec_name].lazy_load_list.extend(entry_points.select(group=group)) + __codec_registries[codec_name].lazy_load_list.extend( + entry_points.select(group=group) + ) return [ *__codec_registries.values(), __pipeline_registry, @@ -131,6 +182,16 @@ def register_buffer(cls: type[Buffer]) -> None: __buffer_registry.register(cls) +def register_v3dtype(cls: type[ZarrDType]) -> None: + assert cls.zarr_spec_format == "3" + __v3_dtype_registry.register(cls, class_registry_key=cls._zarr_spec_identifier) + + +def register_v2dtype(cls: type[ZarrDType]) -> None: + assert cls.zarr_spec_format == "2" + __v2_dtype_registry.register(cls, class_registry_key=cls._zarr_spec_identifier) + + def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if reload_config: _reload_config() @@ -148,7 +209,8 @@ def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]: if len(codec_classes) == 1: return next(iter(codec_classes.values())) warnings.warn( - f"Codec '{key}' not configured in config. Selecting any implementation.", stacklevel=2 + f"Codec '{key}' not configured in config. Selecting any implementation.", + stacklevel=2, ) return list(codec_classes.values())[-1] selected_codec_cls = codec_classes[config_entry] @@ -266,4 +328,50 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]: ) +# TODO: merge the get_vXdtype_class_ functions +# these can be used instead of the various parse_X functions (hopefully) +def get_v3dtype_class(dtype: str) -> type[ZarrDType]: + __v3_dtype_registry.lazy_load() + v3dtype_class = __v3_dtype_registry.get(dtype) + if v3dtype_class: + return v3dtype_class + raise ValueError( + f"ZarrDType class '{dtype}' not found in registered buffers: {list(__v3_dtype_registry)}." + ) + + +def get_v3dtype_class_from_numpy(dtype: npt.DTypeLike) -> type[ZarrDType]: + __v3_dtype_registry.lazy_load() + + dtype = np.dtype(dtype) + for val in __v3_dtype_registry.values(): + if dtype == val.to_numpy: + return val + raise ValueError( + f"numpy dtype '{dtype}' does not have a corresponding Zarr dtype in: {list(__v3_dtype_registry)}." + ) + + +def get_v2dtype_class(dtype: str) -> type[ZarrDType]: + __v2_dtype_registry.lazy_load() + v2dtype_class = __v2_dtype_registry.get(dtype) + if v2dtype_class: + return v2dtype_class + raise ValueError( + f"ZarrDType class '{dtype}' not found in registered buffers: {list(__v2_dtype_registry)}." + ) + + +def get_v2dtype_class_from_numpy(dtype: npt.DTypeLike) -> type[ZarrDType]: + __v2_dtype_registry.lazy_load() + + dtype = np.dtype(dtype) + for val in __v2_dtype_registry.values(): + if dtype == val.to_numpy: + return val + raise ValueError( + f"numpy dtype '{dtype}' does not have a corresponding Zarr dtype in: {list(__v2_dtype_registry)}." + ) + + _collect_entrypoints()