From 483681b76a14a7d1927815ae1def7c4efa9f3282 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 9 Oct 2024 09:30:38 -0500 Subject: [PATCH 1/6] Special case object dtype Closes https://github.com/zarr-developers/zarr-python/issues/2315 --- src/zarr/core/array.py | 4 +++- src/zarr/core/common.py | 9 +++++++++ src/zarr/core/metadata/v2.py | 7 +------ src/zarr/core/metadata/v3.py | 4 ++++ tests/v3/test_array.py | 9 ++++++++- 5 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 9f5591ce1e..12a94c64ae 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -35,6 +35,7 @@ ShapeLike, ZarrFormat, concurrent_map, + parse_dtype, parse_shapelike, product, ) @@ -226,7 +227,8 @@ async def create( if chunks is not None and chunk_shape is not None: raise ValueError("Only one of chunk_shape or chunks can be provided.") - dtype = np.dtype(dtype) + dtype = parse_dtype(dtype) + # dtype = np.dtype(dtype) if chunks: _chunks = normalize_chunks(chunks, shape, dtype.itemsize) else: diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 80c743cc90..93485d1b73 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -16,6 +16,8 @@ overload, ) +import numpy as np + if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -162,3 +164,10 @@ def parse_order(data: Any) -> Literal["C", "F"]: if data in ("C", "F"): return cast(Literal["C", "F"], data) raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.") + + +def parse_dtype(dtype: Any) -> np.dtype[Any]: + if dtype is str or dtype == "str": + # special case as object + return np.dtype("object") + return np.dtype(dtype) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index ec44673f9d..5ac4f9b361 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -22,7 +22,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import parse_separator -from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike +from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike from zarr.core.config import config, parse_indexing_order from zarr.core.metadata.common import ArrayMetadata, parse_attributes @@ -201,11 +201,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) -def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: - # todo: real validation - return np.dtype(data) - - def parse_zarr_format(data: object) -> Literal[2]: if data == 2: return 2 diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 47c6106bfe..d77646249d 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -504,6 +504,7 @@ class DataType(Enum): complex128 = "complex128" string = "string" bytes = "bytes" + object = "object" @property def byte_count(self) -> None | int: @@ -549,6 +550,7 @@ def to_numpy_shortname(self) -> str: DataType.float64: "f8", DataType.complex64: "c8", DataType.complex128: "c16", + DataType.object: "object", } return data_type_to_numpy[self] @@ -572,6 +574,8 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: return DataType.string elif dtype.kind == "S": return DataType.bytes + elif dtype.kind == "O": + return DataType.object dtype_to_data_type = { "|b1": "bool", "bool": "bool", diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 04adb2a224..fa15cd08d7 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,6 +1,6 @@ import pickle from itertools import accumulate -from typing import Literal +from typing import Any, Literal import numpy as np import pytest @@ -406,3 +406,10 @@ def test_vlen_errors() -> None: dtype=" None: + arr = zarr.create(shape=10, dtype=dtype, zarr_format=zarr_format) + assert arr.dtype.kind == "O" From 7e76e9e392c07fbf854a2a1892211bba9e86566b Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 9 Oct 2024 10:08:13 -0500 Subject: [PATCH 2/6] fixup --- src/zarr/core/array.py | 2 +- src/zarr/core/common.py | 11 ++++++++--- src/zarr/core/metadata/v2.py | 4 ++-- src/zarr/core/metadata/v3.py | 4 ---- tests/v3/test_array.py | 9 +-------- tests/v3/test_codecs/test_vlen.py | 2 +- tests/v3/test_v2.py | 8 ++++++++ 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 12a94c64ae..a80e15e896 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -227,7 +227,7 @@ async def create( if chunks is not None and chunk_shape is not None: raise ValueError("Only one of chunk_shape or chunks can be provided.") - dtype = parse_dtype(dtype) + dtype = parse_dtype(dtype, zarr_format) # dtype = np.dtype(dtype) if chunks: _chunks = normalize_chunks(chunks, shape, dtype.itemsize) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 93485d1b73..d74b07ebce 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -18,6 +18,8 @@ import numpy as np +from zarr.core.strings import _STRING_DTYPE + if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -166,8 +168,11 @@ def parse_order(data: Any) -> Literal["C", "F"]: raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.") -def parse_dtype(dtype: Any) -> np.dtype[Any]: +def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]: if dtype is str or dtype == "str": - # special case as object - return np.dtype("object") + if zarr_format == 2: + # special case as object + return np.dtype("object") + else: + return _STRING_DTYPE return np.dtype(dtype) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 5ac4f9b361..ddbf0aa94d 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -57,7 +57,7 @@ def __init__( Metadata for a Zarr version 2 array. """ shape_parsed = parse_shapelike(shape) - data_type_parsed = parse_dtype(dtype) + data_type_parsed = parse_dtype(dtype, zarr_format=2) chunks_parsed = parse_shapelike(chunks) compressor_parsed = parse_compressor(compressor) order_parsed = parse_indexing_order(order) @@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: _data = data.copy() # check that the zarr_format attribute is correct _ = parse_zarr_format(_data.pop("zarr_format")) - dtype = parse_dtype(_data["dtype"]) + dtype = parse_dtype(_data["dtype"], zarr_format=2) if dtype.kind in "SV": fill_value_encoded = _data.get("fill_value") diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index d77646249d..47c6106bfe 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -504,7 +504,6 @@ class DataType(Enum): complex128 = "complex128" string = "string" bytes = "bytes" - object = "object" @property def byte_count(self) -> None | int: @@ -550,7 +549,6 @@ def to_numpy_shortname(self) -> str: DataType.float64: "f8", DataType.complex64: "c8", DataType.complex128: "c16", - DataType.object: "object", } return data_type_to_numpy[self] @@ -574,8 +572,6 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: return DataType.string elif dtype.kind == "S": return DataType.bytes - elif dtype.kind == "O": - return DataType.object dtype_to_data_type = { "|b1": "bool", "bool": "bool", diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index fa15cd08d7..04adb2a224 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,6 +1,6 @@ import pickle from itertools import accumulate -from typing import Any, Literal +from typing import Literal import numpy as np import pytest @@ -406,10 +406,3 @@ def test_vlen_errors() -> None: dtype=" None: - arr = zarr.create(shape=10, dtype=dtype, zarr_format=zarr_format) - assert arr.dtype.kind == "O" diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index c6f587931a..aaea5dab83 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -11,7 +11,7 @@ from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING from zarr.storage.common import StorePath -numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType] +numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType] expected_zarr_string_dtype: np.dtype[Any] if _NUMPY_SUPPORTS_VLEN_STRING: numpy_str_dtypes.append(np.dtypes.StringDType) diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index f488782d78..60e59e49bf 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -1,5 +1,6 @@ import json from collections.abc import Iterator +from typing import Any import numpy as np import pytest @@ -84,3 +85,10 @@ async def test_v2_encode_decode(dtype): data = zarr.open_array(store=store, path="foo")[:] expected = np.full((3,), b"X", dtype=dtype) np.testing.assert_equal(data, expected) + + +@pytest.mark.parametrize("dtype", [str, "str"]) +async def test_create_dtype_str(dtype: Any) -> None: + arr = zarr.create(shape=10, dtype=dtype, zarr_format=2) + assert arr.dtype.kind == "O" + assert arr.metadata.to_dict()["dtype"] == "|O" From 2db00ff2486bdeb551b43f47af0ad52386a1aa37 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 9 Oct 2024 16:18:44 -0500 Subject: [PATCH 3/6] remove dead code --- src/zarr/core/array.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index cf455fe5b4..2802bff76a 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -222,13 +222,12 @@ async def create( ) -> AsyncArray: store_path = await make_store_path(store) + dtype = parse_dtype(dtype, zarr_format) shape = parse_shapelike(shape) if chunks is not None and chunk_shape is not None: raise ValueError("Only one of chunk_shape or chunks can be provided.") - dtype = parse_dtype(dtype, zarr_format) - # dtype = np.dtype(dtype) if chunks: _chunks = normalize_chunks(chunks, shape, dtype.itemsize) else: From 4b0a39ebe1729a4930ce16c24db72d0bc8ff6014 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 10 Oct 2024 10:24:40 -0500 Subject: [PATCH 4/6] fixup --- src/zarr/core/metadata/v2.py | 12 ++++++++---- tests/v3/test_v2.py | 9 +++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 2bae6f2db1..3b9bbbdf5a 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -22,7 +22,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.chunk_grids import RegularChunkGrid from zarr.core.chunk_key_encodings import parse_separator -from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike +from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike from zarr.core.config import config, parse_indexing_order from zarr.core.metadata.common import ArrayMetadata, parse_attributes @@ -57,7 +57,7 @@ def __init__( Metadata for a Zarr version 2 array. """ shape_parsed = parse_shapelike(shape) - data_type_parsed = parse_dtype(dtype, zarr_format=2) + data_type_parsed = parse_dtype(dtype) chunks_parsed = parse_shapelike(chunks) compressor_parsed = parse_compressor(compressor) order_parsed = parse_indexing_order(order) @@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: _data = data.copy() # check that the zarr_format attribute is correct _ = parse_zarr_format(_data.pop("zarr_format")) - dtype = parse_dtype(_data["dtype"], zarr_format=2) + dtype = parse_dtype(_data["dtype"]) if dtype.kind in "SV": fill_value_encoded = _data.get("fill_value") @@ -201,6 +201,10 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) +def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: + return np.dtype(data) + + def parse_zarr_format(data: object) -> Literal[2]: if data == 2: return 2 @@ -312,7 +316,7 @@ def _default_fill_value(dtype: np.dtype[Any]) -> Any: """ if dtype.kind == "S": return b"" - elif dtype.kind == "U": + elif dtype.kind in "UO": return "" else: return dtype.type(0) diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index 55142c6894..9ded7fc481 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -44,7 +44,7 @@ def test_simple(store: StorePath) -> None: ("float64", 0.0), ("|S1", b""), ("|U1", ""), - ("object", 0), + ("object", ""), (str, ""), ], ) @@ -53,7 +53,12 @@ def test_implicit_fill_value(store: StorePath, dtype: str, fill_value: Any) -> N assert arr.metadata.fill_value is None assert arr.metadata.to_dict()["fill_value"] is None result = arr[:] - expected = np.full(arr.shape, fill_value, dtype=dtype) + if dtype is str: + # special case + numpy_dtype = np.dtype(object) + else: + numpy_dtype = np.dtype(dtype) + expected = np.full(arr.shape, fill_value, dtype=numpy_dtype) np.testing.assert_array_equal(result, expected) From d8f24a89be7c318326ef981c04b7e3ee561444b7 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 10 Oct 2024 11:32:19 -0500 Subject: [PATCH 5/6] automatically add filter --- src/zarr/core/array.py | 18 +++++++++++++----- tests/v3/test_v2.py | 2 ++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index adf33f6cbf..971d79fdfa 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -223,16 +223,16 @@ async def create( ) -> AsyncArray: store_path = await make_store_path(store) - dtype = parse_dtype(dtype, zarr_format) + dtype_parsed = parse_dtype(dtype, zarr_format) shape = parse_shapelike(shape) if chunks is not None and chunk_shape is not None: raise ValueError("Only one of chunk_shape or chunks can be provided.") if chunks: - _chunks = normalize_chunks(chunks, shape, dtype.itemsize) + _chunks = normalize_chunks(chunks, shape, dtype_parsed.itemsize) else: - _chunks = normalize_chunks(chunk_shape, shape, dtype.itemsize) + _chunks = normalize_chunks(chunk_shape, shape, dtype_parsed.itemsize) if zarr_format == 3: if dimension_separator is not None: @@ -254,7 +254,7 @@ async def create( result = await cls._create_v3( store_path, shape=shape, - dtype=dtype, + dtype=dtype_parsed, chunk_shape=_chunks, fill_value=fill_value, chunk_key_encoding=chunk_key_encoding, @@ -264,6 +264,14 @@ async def create( exists_ok=exists_ok, ) elif zarr_format == 2: + if dtype is str or dtype == "str": + # another special case: zarr v2 added the vlen-utf8 codec + vlen_codec: dict[str, JSON] = {"id": "vlen-utf8"} + if filters and not any(x["id"] == "vlen-utf8" for x in filters): + filters = list(filters) + [vlen_codec] + else: + filters = [vlen_codec] + if codecs is not None: raise ValueError( "codecs cannot be used for arrays with version 2. Use filters and compressor instead." @@ -277,7 +285,7 @@ async def create( result = await cls._create_v2( store_path, shape=shape, - dtype=dtype, + dtype=dtype_parsed, chunks=_chunks, dimension_separator=dimension_separator, fill_value=fill_value, diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index 9ded7fc481..0628c170dc 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -2,6 +2,7 @@ from collections.abc import Iterator from typing import Any +import numcodecs.vlen import numpy as np import pytest from numcodecs import Delta @@ -118,3 +119,4 @@ async def test_create_dtype_str(dtype: Any) -> None: arr = zarr.create(shape=10, dtype=dtype, zarr_format=2) assert arr.dtype.kind == "O" assert arr.metadata.to_dict()["dtype"] == "|O" + assert arr.metadata.filters == (numcodecs.vlen.VLenUTF8(),) From 509a5c1cb959a68043e39cc37aa50bdadcf72eb9 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 10 Oct 2024 14:53:50 -0500 Subject: [PATCH 6/6] maybe fixed --- src/zarr/codecs/_v2.py | 2 +- tests/v3/test_v2.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py index cc6129e604..dd079b7b76 100644 --- a/src/zarr/codecs/_v2.py +++ b/src/zarr/codecs/_v2.py @@ -36,7 +36,7 @@ async def _decode_single( chunk_numpy_array = ensure_ndarray(chunk_bytes.as_array_like()) # ensure correct dtype - if str(chunk_numpy_array.dtype) != chunk_spec.dtype: + if str(chunk_numpy_array.dtype) != chunk_spec.dtype and not chunk_spec.dtype.hasobject: chunk_numpy_array = chunk_numpy_array.view(chunk_spec.dtype) return get_ndbuffer_class().from_numpy_array(chunk_numpy_array) diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index 0628c170dc..729ed0533f 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -116,7 +116,10 @@ async def test_v2_encode_decode(dtype): @pytest.mark.parametrize("dtype", [str, "str"]) async def test_create_dtype_str(dtype: Any) -> None: - arr = zarr.create(shape=10, dtype=dtype, zarr_format=2) + arr = zarr.create(shape=3, dtype=dtype, zarr_format=2) assert arr.dtype.kind == "O" assert arr.metadata.to_dict()["dtype"] == "|O" assert arr.metadata.filters == (numcodecs.vlen.VLenUTF8(),) + arr[:] = ["a", "bb", "ccc"] + result = arr[:] + np.testing.assert_array_equal(result, np.array(["a", "bb", "ccc"], dtype="object"))