Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special case str dtype in array creation #2323

Merged
3 changes: 2 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ShapeLike,
ZarrFormat,
concurrent_map,
parse_dtype,
parse_shapelike,
product,
)
Expand Down Expand Up @@ -222,12 +223,12 @@ async def create(
) -> AsyncArray:
store_path = await make_store_path(store)

dtype = parse_dtype(dtype, zarr_format)
shape = parse_shapelike(shape)

if chunks is not None and chunk_shape is not None:
raise ValueError("Only one of chunk_shape or chunks can be provided.")

dtype = np.dtype(dtype)
if chunks:
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
else:
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
overload,
)

import numpy as np

from zarr.core.strings import _STRING_DTYPE

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

Expand Down Expand Up @@ -162,3 +166,13 @@ def parse_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
if dtype is str or dtype == "str":
if zarr_format == 2:
# special case as object
return np.dtype("object")
else:
return _STRING_DTYPE
return np.dtype(dtype)
3 changes: 1 addition & 2 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
# todo: real validation
return np.dtype(data)


Expand Down Expand Up @@ -317,7 +316,7 @@ def _default_fill_value(dtype: np.dtype[Any]) -> Any:
"""
if dtype.kind == "S":
return b""
elif dtype.kind == "U":
elif dtype.kind in "UO":
return ""
else:
return dtype.type(0)
2 changes: 1 addition & 1 deletion tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
from zarr.storage.common import StorePath

numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType]
expected_zarr_string_dtype: np.dtype[Any]
if _NUMPY_SUPPORTS_VLEN_STRING:
numpy_str_dtypes.append(np.dtypes.StringDType)
Expand Down
16 changes: 14 additions & 2 deletions tests/v3/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_simple(store: StorePath) -> None:
("float64", 0.0),
("|S1", b""),
("|U1", ""),
("object", 0),
("object", ""),
TomAugspurger marked this conversation as resolved.
Show resolved Hide resolved
(str, ""),
],
)
Expand All @@ -53,7 +53,12 @@ def test_implicit_fill_value(store: StorePath, dtype: str, fill_value: Any) -> N
assert arr.metadata.fill_value is None
assert arr.metadata.to_dict()["fill_value"] is None
result = arr[:]
expected = np.full(arr.shape, fill_value, dtype=dtype)
if dtype is str:
# special case
numpy_dtype = np.dtype(object)
else:
numpy_dtype = np.dtype(dtype)
expected = np.full(arr.shape, fill_value, dtype=numpy_dtype)
np.testing.assert_array_equal(result, expected)


Expand Down Expand Up @@ -106,3 +111,10 @@ async def test_v2_encode_decode(dtype):
data = zarr.open_array(store=store, path="foo")[:]
expected = np.full((3,), b"X", dtype=dtype)
np.testing.assert_equal(data, expected)


@pytest.mark.parametrize("dtype", [str, "str"])
async def test_create_dtype_str(dtype: Any) -> None:
arr = zarr.create(shape=10, dtype=dtype, zarr_format=2)
assert arr.dtype.kind == "O"
assert arr.metadata.to_dict()["dtype"] == "|O"