Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special case str dtype in array creation #2323

Merged
2 changes: 1 addition & 1 deletion src/zarr/codecs/_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def _decode_single(
chunk_numpy_array = ensure_ndarray(chunk_bytes.as_array_like())

# ensure correct dtype
if str(chunk_numpy_array.dtype) != chunk_spec.dtype:
if str(chunk_numpy_array.dtype) != chunk_spec.dtype and not chunk_spec.dtype.hasobject:
chunk_numpy_array = chunk_numpy_array.view(chunk_spec.dtype)

return get_ndbuffer_class().from_numpy_array(chunk_numpy_array)
Expand Down
20 changes: 15 additions & 5 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ShapeLike,
ZarrFormat,
concurrent_map,
parse_dtype,
parse_shapelike,
product,
)
Expand Down Expand Up @@ -365,16 +366,17 @@ async def create(
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
store_path = await make_store_path(store)

dtype_parsed = parse_dtype(dtype, zarr_format)
shape = parse_shapelike(shape)

if chunks is not None and chunk_shape is not None:
raise ValueError("Only one of chunk_shape or chunks can be provided.")

dtype = np.dtype(dtype)
if chunks:
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
_chunks = normalize_chunks(chunks, shape, dtype_parsed.itemsize)
else:
_chunks = normalize_chunks(chunk_shape, shape, dtype.itemsize)
_chunks = normalize_chunks(chunk_shape, shape, dtype_parsed.itemsize)

result: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
if zarr_format == 3:
if dimension_separator is not None:
Expand All @@ -396,7 +398,7 @@ async def create(
result = await cls._create_v3(
store_path,
shape=shape,
dtype=dtype,
dtype=dtype_parsed,
chunk_shape=_chunks,
fill_value=fill_value,
chunk_key_encoding=chunk_key_encoding,
Expand All @@ -406,6 +408,14 @@ async def create(
exists_ok=exists_ok,
)
elif zarr_format == 2:
if dtype is str or dtype == "str":
# another special case: zarr v2 added the vlen-utf8 codec
vlen_codec: dict[str, JSON] = {"id": "vlen-utf8"}
if filters and not any(x["id"] == "vlen-utf8" for x in filters):
filters = list(filters) + [vlen_codec]
else:
filters = [vlen_codec]

if codecs is not None:
raise ValueError(
"codecs cannot be used for arrays with version 2. Use filters and compressor instead."
Expand All @@ -419,7 +429,7 @@ async def create(
result = await cls._create_v2(
store_path,
shape=shape,
dtype=dtype,
dtype=dtype_parsed,
chunks=_chunks,
dimension_separator=dimension_separator,
fill_value=fill_value,
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
overload,
)

import numpy as np

from zarr.core.strings import _STRING_DTYPE

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

Expand Down Expand Up @@ -151,3 +155,13 @@ def parse_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
    """Convert a user-supplied dtype specification into a NumPy dtype.

    The builtin ``str`` type and the literal ``"str"`` are special-cased:
    for ``zarr_format == 2`` they map to the NumPy ``object`` dtype, and for
    any other format they map to ``_STRING_DTYPE``. Every other value is
    passed to ``np.dtype`` unchanged.

    Parameters
    ----------
    dtype
        Anything accepted by ``np.dtype``, or ``str`` / ``"str"``.
    zarr_format
        The Zarr format version the dtype is being parsed for.

    Returns
    -------
    np.dtype
        The resolved NumPy dtype.
    """
    is_str_request = dtype is str or dtype == "str"
    if not is_str_request:
        return np.dtype(dtype)

    if zarr_format == 2:
        # special case: string data is represented with the object dtype
        return np.dtype("object")
    return _STRING_DTYPE
2 changes: 1 addition & 1 deletion src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _default_fill_value(dtype: np.dtype[Any]) -> Any:
"""
if dtype.kind == "S":
return b""
elif dtype.kind == "U":
elif dtype.kind in "UO":
return ""
else:
return dtype.type(0)
2 changes: 1 addition & 1 deletion tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
from zarr.storage.common import StorePath

numpy_str_dtypes: list[type | str | None] = [None, str, np.dtypes.StrDType]
numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType]
expected_zarr_string_dtype: np.dtype[Any]
if _NUMPY_SUPPORTS_VLEN_STRING:
numpy_str_dtypes.append(np.dtypes.StringDType)
Expand Down
21 changes: 19 additions & 2 deletions tests/v3/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Iterator
from typing import Any

import numcodecs.vlen
import numpy as np
import pytest
from numcodecs import Delta
Expand Down Expand Up @@ -44,7 +45,7 @@ def test_simple(store: StorePath) -> None:
("float64", 0.0),
("|S1", b""),
("|U1", ""),
("object", 0),
("object", ""),
TomAugspurger marked this conversation as resolved.
Show resolved Hide resolved
(str, ""),
],
)
Expand All @@ -53,7 +54,12 @@ def test_implicit_fill_value(store: StorePath, dtype: str, fill_value: Any) -> N
assert arr.metadata.fill_value is None
assert arr.metadata.to_dict()["fill_value"] is None
result = arr[:]
expected = np.full(arr.shape, fill_value, dtype=dtype)
if dtype is str:
# special case
numpy_dtype = np.dtype(object)
else:
numpy_dtype = np.dtype(dtype)
expected = np.full(arr.shape, fill_value, dtype=numpy_dtype)
np.testing.assert_array_equal(result, expected)


Expand Down Expand Up @@ -106,3 +112,14 @@ async def test_v2_encode_decode(dtype):
data = zarr.open_array(store=store, path="foo")[:]
expected = np.full((3,), b"X", dtype=dtype)
np.testing.assert_equal(data, expected)


@pytest.mark.parametrize("dtype", [str, "str"])
async def test_create_dtype_str(dtype: Any) -> None:
    """Check v2 array creation when the dtype is given as ``str`` or ``"str"``.

    The created array is expected to report an object dtype (``"|O"`` in its
    metadata), to carry a single ``VLenUTF8`` filter, and to round-trip a
    small list of strings.
    """
    strings = ["a", "bb", "ccc"]

    arr = zarr.create(shape=3, dtype=dtype, zarr_format=2)

    # dtype and filter configuration recorded on the array
    assert arr.dtype.kind == "O"
    metadata_dict = arr.metadata.to_dict()
    assert metadata_dict["dtype"] == "|O"
    expected_filters = (numcodecs.vlen.VLenUTF8(),)
    assert arr.metadata.filters == expected_filters

    # write the strings, read them back, and compare as an object array
    arr[:] = strings
    roundtripped = arr[:]
    expected = np.array(strings, dtype="object")
    np.testing.assert_array_equal(roundtripped, expected)