From 7e76e9e392c07fbf854a2a1892211bba9e86566b Mon Sep 17 00:00:00 2001
From: Tom Augspurger <tom.w.augspurger@gmail.com>
Date: Wed, 9 Oct 2024 10:08:13 -0500
Subject: [PATCH] fixup

---
 src/zarr/core/array.py            |  2 +-
 src/zarr/core/common.py           | 11 ++++++++---
 src/zarr/core/metadata/v2.py      |  4 ++--
 src/zarr/core/metadata/v3.py      |  4 ----
 tests/v3/test_array.py            |  9 +--------
 tests/v3/test_codecs/test_vlen.py |  2 +-
 tests/v3/test_v2.py               |  8 ++++++++
 7 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py
index 12a94c64ae..a80e15e896 100644
--- a/src/zarr/core/array.py
+++ b/src/zarr/core/array.py
@@ -227,7 +227,7 @@ async def create(
         if chunks is not None and chunk_shape is not None:
             raise ValueError("Only one of chunk_shape or chunks can be provided.")
 
-        dtype = parse_dtype(dtype)
+        dtype = parse_dtype(dtype, zarr_format)
         # dtype = np.dtype(dtype)
         if chunks:
             _chunks = normalize_chunks(chunks, shape, dtype.itemsize)
diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py
index 93485d1b73..d74b07ebce 100644
--- a/src/zarr/core/common.py
+++ b/src/zarr/core/common.py
@@ -18,6 +18,8 @@
 
 import numpy as np
 
+from zarr.core.strings import _STRING_DTYPE
+
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Iterator
 
@@ -166,8 +168,11 @@ def parse_order(data: Any) -> Literal["C", "F"]:
     raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")
 
 
-def parse_dtype(dtype: Any) -> np.dtype[Any]:
+def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
     if dtype is str or dtype == "str":
-        # special case as object
-        return np.dtype("object")
+        if zarr_format == 2:
+            # special case as object
+            return np.dtype("object")
+        else:
+            return _STRING_DTYPE
     return np.dtype(dtype)
diff --git a/src/zarr/core/metadata/v2.py b/src/zarr/core/metadata/v2.py
index 5ac4f9b361..ddbf0aa94d 100644
--- a/src/zarr/core/metadata/v2.py
+++ b/src/zarr/core/metadata/v2.py
@@ -57,7 +57,7 @@ def __init__(
         Metadata for a Zarr version 2 array.
         """
         shape_parsed = parse_shapelike(shape)
-        data_type_parsed = parse_dtype(dtype)
+        data_type_parsed = parse_dtype(dtype, zarr_format=2)
         chunks_parsed = parse_shapelike(chunks)
         compressor_parsed = parse_compressor(compressor)
         order_parsed = parse_indexing_order(order)
@@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
         _data = data.copy()
         # check that the zarr_format attribute is correct
         _ = parse_zarr_format(_data.pop("zarr_format"))
-        dtype = parse_dtype(_data["dtype"])
+        dtype = parse_dtype(_data["dtype"], zarr_format=2)
 
         if dtype.kind in "SV":
             fill_value_encoded = _data.get("fill_value")
diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py
index d77646249d..47c6106bfe 100644
--- a/src/zarr/core/metadata/v3.py
+++ b/src/zarr/core/metadata/v3.py
@@ -504,7 +504,6 @@ class DataType(Enum):
     complex128 = "complex128"
     string = "string"
     bytes = "bytes"
-    object = "object"
 
     @property
     def byte_count(self) -> None | int:
@@ -550,7 +549,6 @@ def to_numpy_shortname(self) -> str:
             DataType.float64: "f8",
             DataType.complex64: "c8",
             DataType.complex128: "c16",
-            DataType.object: "object",
         }
         return data_type_to_numpy[self]
 
@@ -574,8 +572,6 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
             return DataType.string
         elif dtype.kind == "S":
             return DataType.bytes
-        elif dtype.kind == "O":
-            return DataType.object
         dtype_to_data_type = {
             "|b1": "bool",
             "bool": "bool",
diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py
index fa15cd08d7..04adb2a224 100644
--- a/tests/v3/test_array.py
+++ b/tests/v3/test_array.py
@@ -1,6 +1,6 @@
 import pickle
 from itertools import accumulate
-from typing import Any, Literal
+from typing import Literal
 
 import numpy as np
 import pytest
@@ -406,10 +406,3 @@ def test_vlen_errors() -> None:
             dtype="<U4",
             codecs=[BytesCodec(), VLenBytesCodec()],
         )
-
-
-@pytest.mark.parametrize("zarr_format", [2, 3, None])
-@pytest.mark.parametrize("dtype", [str, "str"])
-def test_create_dtype_str(dtype: Any, zarr_format: ZarrFormat | None) -> None:
-    arr = zarr.create(shape=10, dtype=dtype, zarr_format=zarr_format)
-    assert arr.dtype.kind == "O"
diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py
index c6f587931a..aaea5dab83 100644
--- a/tests/v3/test_codecs/test_vlen.py
+++ b/tests/v3/test_codecs/test_vlen.py
@@ -11,7 +11,7 @@
 from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
 from zarr.storage.common import StorePath
 
-numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
+numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType]
 expected_zarr_string_dtype: np.dtype[Any]
 if _NUMPY_SUPPORTS_VLEN_STRING:
     numpy_str_dtypes.append(np.dtypes.StringDType)
diff --git a/tests/v3/test_v2.py b/tests/v3/test_v2.py
index f488782d78..60e59e49bf 100644
--- a/tests/v3/test_v2.py
+++ b/tests/v3/test_v2.py
@@ -1,5 +1,6 @@
 import json
 from collections.abc import Iterator
+from typing import Any
 
 import numpy as np
 import pytest
@@ -84,3 +85,10 @@ async def test_v2_encode_decode(dtype):
     data = zarr.open_array(store=store, path="foo")[:]
     expected = np.full((3,), b"X", dtype=dtype)
     np.testing.assert_equal(data, expected)
+
+
+@pytest.mark.parametrize("dtype", [str, "str"])
+async def test_create_dtype_str(dtype: Any) -> None:
+    arr = zarr.create(shape=10, dtype=dtype, zarr_format=2)
+    assert arr.dtype.kind == "O"
+    assert arr.metadata.to_dict()["dtype"] == "|O"