From 7e76e9e392c07fbf854a2a1892211bba9e86566b Mon Sep 17 00:00:00 2001 From: Tom Augspurger <tom.w.augspurger@gmail.com> Date: Wed, 9 Oct 2024 10:08:13 -0500 Subject: [PATCH] fixup --- src/zarr/core/array.py | 2 +- src/zarr/core/common.py | 11 ++++++++--- src/zarr/core/metadata/v2.py | 4 ++-- src/zarr/core/metadata/v3.py | 4 ---- tests/v3/test_array.py | 9 +-------- tests/v3/test_codecs/test_vlen.py | 2 +- tests/v3/test_v2.py | 8 ++++++++ 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 12a94c64ae..a80e15e896 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -227,7 +227,7 @@ async def create( if chunks is not None and chunk_shape is not None: raise ValueError("Only one of chunk_shape or chunks can be provided.") - dtype = parse_dtype(dtype) + dtype = parse_dtype(dtype, zarr_format) # dtype = np.dtype(dtype) if chunks: _chunks = normalize_chunks(chunks, shape, dtype.itemsize) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 93485d1b73..d74b07ebce 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -18,6 +18,8 @@ import numpy as np +from zarr.core.strings import _STRING_DTYPE + if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator @@ -166,8 +168,11 @@ def parse_order(data: Any) -> Literal["C", "F"]: raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.") -def parse_dtype(dtype: Any) -> np.dtype[Any]: +def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]: if dtype is str or dtype == "str": - # special case as object - return np.dtype("object") + if zarr_format == 2: + # special case as object + return np.dtype("object") + else: + return _STRING_DTYPE return np.dtype(dtype) diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py index 5ac4f9b361..ddbf0aa94d 100644 --- a/src/zarr/core/metadata/v2.py +++ b/src/zarr/core/metadata/v2.py @@ -57,7 +57,7 @@ def __init__( Metadata for a Zarr version 2 array. """ shape_parsed = parse_shapelike(shape) - data_type_parsed = parse_dtype(dtype) + data_type_parsed = parse_dtype(dtype, zarr_format=2) chunks_parsed = parse_shapelike(chunks) compressor_parsed = parse_compressor(compressor) order_parsed = parse_indexing_order(order) @@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata: _data = data.copy() # check that the zarr_format attribute is correct _ = parse_zarr_format(_data.pop("zarr_format")) - dtype = parse_dtype(_data["dtype"]) + dtype = parse_dtype(_data["dtype"], zarr_format=2) if dtype.kind in "SV": fill_value_encoded = _data.get("fill_value") diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index d77646249d..47c6106bfe 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -504,7 +504,6 @@ class DataType(Enum): complex128 = "complex128" string = "string" bytes = "bytes" - object = "object" @property def byte_count(self) -> None | int: @@ -550,7 +549,6 @@ def to_numpy_shortname(self) -> str: DataType.float64: "f8", DataType.complex64: "c8", DataType.complex128: "c16", - DataType.object: "object", } return data_type_to_numpy[self] @@ -574,8 +572,6 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: return DataType.string elif dtype.kind == "S": return DataType.bytes - elif dtype.kind == "O": - return DataType.object dtype_to_data_type = { "|b1": "bool", "bool": "bool", diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index fa15cd08d7..04adb2a224 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,6 +1,6 @@ import pickle from itertools import accumulate -from typing import Any, Literal +from typing import Literal import numpy as np import pytest @@ -406,10 +406,3 @@ def test_vlen_errors() -> None: dtype="<U4", codecs=[BytesCodec(), VLenBytesCodec()], ) - - -@pytest.mark.parametrize("zarr_format", [2, 3, None]) -@pytest.mark.parametrize("dtype", [str, "str"]) -def test_create_dtype_str(dtype: Any, zarr_format: ZarrFormat | None) -> None: - arr = zarr.create(shape=10, dtype=dtype, zarr_format=zarr_format) - assert arr.dtype.kind == "O" diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index c6f587931a..aaea5dab83 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -11,7 +11,7 @@ from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING from zarr.storage.common import StorePath -numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType] +numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType] expected_zarr_string_dtype: np.dtype[Any] if _NUMPY_SUPPORTS_VLEN_STRING: numpy_str_dtypes.append(np.dtypes.StringDType) diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py index f488782d78..60e59e49bf 100644 --- a/tests/v3/test_v2.py +++ b/tests/v3/test_v2.py @@ -1,5 +1,6 @@ import json from collections.abc import Iterator +from typing import Any import numpy as np import pytest @@ -84,3 +85,10 @@ async def test_v2_encode_decode(dtype): data = zarr.open_array(store=store, path="foo")[:] expected = np.full((3,), b"X", dtype=dtype) np.testing.assert_equal(data, expected) + + +@pytest.mark.parametrize("dtype", [str, "str"]) +async def test_create_dtype_str(dtype: Any) -> None: + arr = zarr.create(shape=10, dtype=dtype, zarr_format=2) + assert arr.dtype.kind == "O" + assert arr.metadata.to_dict()["dtype"] == "|O"