Skip to content

Commit

Permalink
String dtype: implement object-dtype based StringArray variant with N…
Browse files Browse the repository at this point in the history
…umPy semantics (pandas-dev#58451)

Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
  • Loading branch information
2 people authored and jorisvandenbossche committed Oct 3, 2024
1 parent a9dc596 commit 5ee61c3
Show file tree
Hide file tree
Showing 14 changed files with 232 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2728,7 +2728,7 @@ def maybe_convert_objects(ndarray[object] objects,
if using_string_dtype() and is_string_array(objects, skipna=True):
from pandas.core.arrays.string_ import StringDtype

dtype = StringDtype(storage="pyarrow", na_value=np.nan)
dtype = StringDtype(na_value=np.nan)
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)

seen.object_ = True
Expand Down
18 changes: 18 additions & 0 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,24 @@ def assert_extension_array_equal(
left_na, right_na, obj=f"{obj} NA mask", index_values=index_values
)

# Specifically for StringArrayNumpySemantics, validate here we have a valid array
if (
isinstance(left.dtype, StringDtype)
and left.dtype.storage == "python"
and left.dtype.na_value is np.nan
):
assert np.all(
[np.isnan(val) for val in left._ndarray[left_na]] # type: ignore[attr-defined]
), "wrong missing value sentinels"
if (
isinstance(right.dtype, StringDtype)
and right.dtype.storage == "python"
and right.dtype.na_value is np.nan
):
assert np.all(
[np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined]
), "wrong missing value sentinels"

left_valid = left[~left_na].to_numpy(dtype=object)
right_valid = right[~right_na].to_numpy(dtype=object)
if check_exact:
Expand Down
2 changes: 2 additions & 0 deletions pandas/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pandas.compat.compressors
from pandas.compat.numpy import is_numpy_dev
from pandas.compat.pyarrow import (
HAS_PYARROW,
pa_version_under10p1,
pa_version_under11p0,
pa_version_under13p0,
Expand Down Expand Up @@ -190,6 +191,7 @@ def get_bz2_file() -> type[pandas.compat.compressors.BZ2File]:
"pa_version_under14p1",
"pa_version_under16p0",
"pa_version_under17p0",
"HAS_PYARROW",
"IS64",
"ISMUSL",
"PY310",
Expand Down
2 changes: 2 additions & 0 deletions pandas/compat/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
pa_version_under15p0 = _palv < Version("15.0.0")
pa_version_under16p0 = _palv < Version("16.0.0")
pa_version_under17p0 = _palv < Version("17.0.0")
HAS_PYARROW = True
except ImportError:
pa_version_under10p1 = True
pa_version_under11p0 = True
Expand All @@ -27,3 +28,4 @@
pa_version_under15p0 = True
pa_version_under16p0 = True
pa_version_under17p0 = True
HAS_PYARROW = False
4 changes: 4 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,7 @@ def string_storage(request):
("python", pd.NA),
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
("python", np.nan),
]
)
def string_dtype_arguments(request):
Expand Down Expand Up @@ -1326,12 +1327,14 @@ def object_dtype(request):
("python", pd.NA),
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
("python", np.nan),
],
ids=[
"string=object",
"string=string[python]",
"string=string[pyarrow]",
"string=str[pyarrow]",
"string=str[python]",
],
)
def any_string_dtype(request):
Expand All @@ -1341,6 +1344,7 @@ def any_string_dtype(request):
* 'string[python]' (NA variant)
* 'string[pyarrow]' (NA variant)
* 'str' (NaN variant, with pyarrow)
* 'str' (NaN variant, without pyarrow)
"""
if isinstance(request.param, np.dtype):
return request.param
Expand Down
Loading

0 comments on commit 5ee61c3

Please sign in to comment.