From 463fd91aecaa5083278a78d209790fb91f2daa25 Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Wed, 14 Aug 2024 12:15:25 -0400 Subject: [PATCH] String dtype: implement object-dtype based StringArray variant with NumPy semantics (#58451) Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com> --- pandas/_libs/lib.pyx | 2 +- pandas/_testing/asserters.py | 18 ++ pandas/compat/__init__.py | 2 + pandas/compat/pyarrow.py | 2 + pandas/conftest.py | 4 + pandas/core/arrays/string_.py | 183 +++++++++++++++++++-- pandas/core/construction.py | 4 +- pandas/core/dtypes/cast.py | 2 +- pandas/core/internals/construction.py | 2 +- pandas/io/_util.py | 4 +- pandas/io/pytables.py | 7 +- pandas/tests/arrays/string_/test_string.py | 22 ++- pandas/tests/series/test_constructors.py | 26 ++- pandas/tests/strings/test_find_replace.py | 2 +- 14 files changed, 232 insertions(+), 48 deletions(-) diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 1222b33aac3c1..5d8a04664b0e4 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -2728,7 +2728,7 @@ def maybe_convert_objects(ndarray[object] objects, if using_string_dtype() and is_string_array(objects, skipna=True): from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype(storage="pyarrow", na_value=np.nan) + dtype = StringDtype(na_value=np.nan) return dtype.construct_array_type()._from_sequence(objects, dtype=dtype) seen.object_ = True diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py index e79e353137152..a1f9844669c8c 100644 --- a/pandas/_testing/asserters.py +++ b/pandas/_testing/asserters.py @@ -811,6 +811,24 @@ def assert_extension_array_equal( left_na, right_na, obj=f"{obj} NA mask", index_values=index_values ) + # Specifically for StringArrayNumpySemantics, validate here we have a valid array + if ( + isinstance(left.dtype, StringDtype) + and left.dtype.storage == "python" + and left.dtype.na_value is np.nan + ): + assert np.all( + [np.isnan(val) for val in left._ndarray[left_na]] # type: ignore[attr-defined] + ), "wrong missing value sentinels" + if ( + isinstance(right.dtype, StringDtype) + and right.dtype.storage == "python" + and right.dtype.na_value is np.nan + ): + assert np.all( + [np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined] + ), "wrong missing value sentinels" + left_valid = left[~left_na].to_numpy(dtype=object) right_valid = right[~right_na].to_numpy(dtype=object) if check_exact: diff --git a/pandas/compat/__init__.py b/pandas/compat/__init__.py index 5ada6d705172f..38fb0188df5ff 100644 --- a/pandas/compat/__init__.py +++ b/pandas/compat/__init__.py @@ -25,6 +25,7 @@ import pandas.compat.compressors from pandas.compat.numpy import is_numpy_dev from pandas.compat.pyarrow import ( + HAS_PYARROW, pa_version_under10p1, pa_version_under11p0, pa_version_under13p0, @@ -190,6 +191,7 @@ def get_bz2_file() -> type[pandas.compat.compressors.BZ2File]: "pa_version_under14p1", "pa_version_under16p0", "pa_version_under17p0", + "HAS_PYARROW", "IS64", "ISMUSL", "PY310", diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py index 457d26766520d..7fa197c4a9824 100644 --- a/pandas/compat/pyarrow.py +++ b/pandas/compat/pyarrow.py @@ -17,6 +17,7 @@ pa_version_under15p0 = _palv < Version("15.0.0") pa_version_under16p0 = _palv < Version("16.0.0") pa_version_under17p0 = _palv < Version("17.0.0") + HAS_PYARROW = True except ImportError: pa_version_under10p1 = True pa_version_under11p0 = True @@ -27,3 +28,4 @@ pa_version_under15p0 = True pa_version_under16p0 = True pa_version_under17p0 = True + HAS_PYARROW = False diff --git a/pandas/conftest.py b/pandas/conftest.py index 78cdc2ac5a2bb..433ea7275223d 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -1265,6 +1265,7 @@ def string_storage(request): ("python", pd.NA), pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")), pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")), + ("python", np.nan), ] ) def string_dtype_arguments(request): @@ -1326,12 +1327,14 @@ def object_dtype(request): ("python", pd.NA), pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")), pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")), + ("python", np.nan), ], ids=[ "string=object", "string=string[python]", "string=string[pyarrow]", "string=str[pyarrow]", + "string=str[python]", ], ) def any_string_dtype(request): @@ -1341,6 +1344,7 @@ def any_string_dtype(request): * 'string[python]' (NA variant) * 'string[pyarrow]' (NA variant) * 'str' (NaN variant, with pyarrow) + * 'str' (NaN variant, without pyarrow) """ if isinstance(request.param, np.dtype): return request.param diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index c4f208027c9da..6a3cd50b9c288 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -1,9 +1,12 @@ from __future__ import annotations +import operator from typing import ( TYPE_CHECKING, + Any, ClassVar, Literal, + cast, ) import numpy as np @@ -19,7 +22,10 @@ ) from pandas._libs.arrays import NDArrayBacked from pandas._libs.lib import ensure_string_array -from pandas.compat import pa_version_under10p1 +from pandas.compat import ( + HAS_PYARROW, + pa_version_under10p1, +) from pandas.compat.numpy import function as nv from pandas.util._decorators import doc @@ -37,7 +43,10 @@ pandas_dtype, ) -from pandas.core import ops +from pandas.core import ( + nanops, + ops, +) from pandas.core.array_algos import masked_reductions from pandas.core.arrays.base import ExtensionArray from pandas.core.arrays.floating import ( @@ -125,7 +134,10 @@ def __init__( # infer defaults if storage is None: if using_string_dtype() and na_value is not libmissing.NA: - storage = "pyarrow" + if HAS_PYARROW: + storage = "pyarrow" + else: + storage = "python" else: storage = get_option("mode.string_storage") @@ -243,10 +255,12 @@ def construct_array_type( # type: ignore[override] ArrowStringArrayNumpySemantics, ) - if self.storage == "python": + if self.storage == "python" and self._na_value is libmissing.NA: return StringArray elif self.storage == "pyarrow" and self._na_value is libmissing.NA: return ArrowStringArray + elif self.storage == "python": + return StringArrayNumpySemantics else: return ArrowStringArrayNumpySemantics @@ -282,7 +296,7 @@ def __from_arrow__( # convert chunk by chunk to numpy and concatenate then, to avoid # overflow for large string data when concatenating the pyarrow arrays arr = arr.to_numpy(zero_copy_only=False) - arr = ensure_string_array(arr, na_value=libmissing.NA) + arr = ensure_string_array(arr, na_value=self.na_value) results.append(arr) if len(chunks) == 0: @@ -292,11 +306,7 @@ def __from_arrow__( # Bypass validation inside StringArray constructor, see GH#47781 new_string_array = StringArray.__new__(StringArray) - NDArrayBacked.__init__( - new_string_array, - arr, - StringDtype(storage="python"), - ) + NDArrayBacked.__init__(new_string_array, arr, self) return new_string_array @@ -404,6 +414,8 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc] # undo the NumpyExtensionArray hack _typ = "extension" + _storage = "python" + _na_value: libmissing.NAType | float = libmissing.NA def __init__(self, values, copy: bool = False) -> None: values = extract_array(values) @@ -411,7 +423,11 @@ def __init__(self, values, copy: bool = False) -> None: super().__init__(values, copy=copy) if not isinstance(values, type(self)): self._validate() - NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage="python")) + NDArrayBacked.__init__( + self, + self._ndarray, + StringDtype(storage=self._storage, na_value=self._na_value), + ) def _validate(self): """Validate that we only store NA or strings.""" @@ -429,20 +445,36 @@ def _validate(self): else: lib.convert_nans_to_NA(self._ndarray) + def _validate_scalar(self, value): + # used by NDArrayBackedExtensionIndex.insert + if isna(value): + return self.dtype.na_value + elif not isinstance(value, str): + raise TypeError( + f"Cannot set non-string value '{value}' into a string array." + ) + return value + @classmethod def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): if dtype and not (isinstance(dtype, str) and dtype == "string"): dtype = pandas_dtype(dtype) assert isinstance(dtype, StringDtype) and dtype.storage == "python" + else: + if using_string_dtype(): + dtype = StringDtype(storage="python", na_value=np.nan) + else: + dtype = StringDtype(storage="python") from pandas.core.arrays.masked import BaseMaskedArray + na_value = dtype.na_value if isinstance(scalars, BaseMaskedArray): # avoid costly conversion to object dtype na_values = scalars._mask result = scalars._data result = lib.ensure_string_array(result, copy=copy, convert_na_value=False) - result[na_values] = libmissing.NA + result[na_values] = na_value else: if lib.is_pyarrow_array(scalars): @@ -451,12 +483,12 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal # zero_copy_only to True which caused problems see GH#52076 scalars = np.array(scalars) # convert non-na-likes to str, and nan-likes to StringDtype().na_value - result = lib.ensure_string_array(scalars, na_value=libmissing.NA, copy=copy) + result = lib.ensure_string_array(scalars, na_value=na_value, copy=copy) # Manually creating new array avoids the validation step in the __init__, so is # faster. Refactor need for validation? new_string_array = cls.__new__(cls) - NDArrayBacked.__init__(new_string_array, result, StringDtype(storage="python")) + NDArrayBacked.__init__(new_string_array, result, dtype) return new_string_array @@ -506,7 +538,7 @@ def __setitem__(self, key, value) -> None: # validate new items if scalar_value: if isna(value): - value = libmissing.NA + value = self.dtype.na_value elif not isinstance(value, str): raise TypeError( f"Cannot set non-string value '{value}' into a StringArray." @@ -520,7 +552,7 @@ def __setitem__(self, key, value) -> None: mask = isna(value) if mask.any(): value = value.copy() - value[isna(value)] = libmissing.NA + value[isna(value)] = self.dtype.na_value super().__setitem__(key, value) @@ -633,9 +665,9 @@ def _cmp_method(self, other, op): if op.__name__ in ops.ARITHMETIC_BINOPS: result = np.empty_like(self._ndarray, dtype="object") - result[mask] = libmissing.NA + result[mask] = self.dtype.na_value result[valid] = op(self._ndarray[valid], other) - return StringArray(result) + return self._from_backing_data(result) else: # logical result = np.zeros(len(self._ndarray), dtype="bool") @@ -704,3 +736,118 @@ def _str_map( # or .findall returns a list). # -> We don't know the result type. E.g. `.get` can return anything. return lib.map_infer_mask(arr, f, mask.view("uint8")) + + +class StringArrayNumpySemantics(StringArray): + _storage = "python" + _na_value = np.nan + + def _validate(self) -> None: + """Validate that we only store NaN or strings.""" + if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True): + raise ValueError( + "StringArrayNumpySemantics requires a sequence of strings or NaN" + ) + if self._ndarray.dtype != "object": + raise ValueError( + "StringArrayNumpySemantics requires a sequence of strings or NaN. Got " + f"'{self._ndarray.dtype}' dtype instead." + ) + # TODO validate or force NA/None to NaN + + @classmethod + def _from_sequence( + cls, scalars, *, dtype: Dtype | None = None, copy: bool = False + ) -> Self: + if dtype is None: + dtype = StringDtype(storage="python", na_value=np.nan) + return super()._from_sequence(scalars, dtype=dtype, copy=copy) + + def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics: + # need to override NumpyExtensionArray._from_backing_data to ensure + # we always preserve the dtype + return NDArrayBacked._from_backing_data(self, arr) + + def _reduce( + self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs + ): + if name in ["any", "all"]: + if name == "any": + return nanops.nanany(self._ndarray, skipna=skipna) + else: + return nanops.nanall(self._ndarray, skipna=skipna) + else: + return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs) + + def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any: + # the masked_reductions use pd.NA + if result is libmissing.NA: + return np.nan + return super()._wrap_reduction_result(axis, result) + + def _cmp_method(self, other, op): + result = super()._cmp_method(other, op) + if op == operator.ne: + return result.to_numpy(np.bool_, na_value=True) + else: + return result.to_numpy(np.bool_, na_value=False) + + def value_counts(self, dropna: bool = True) -> Series: + from pandas.core.algorithms import value_counts_internal as value_counts + + result = value_counts(self._ndarray, sort=False, dropna=dropna) + result.index = result.index.astype(self.dtype) + return result + + # ------------------------------------------------------------------------ + # String methods interface + _str_na_value = np.nan + + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): + if dtype is None: + dtype = self.dtype + if na_value is None: + na_value = self.dtype.na_value + + mask = isna(self) + arr = np.asarray(self) + convert = convert and not np.all(mask) + + if is_integer_dtype(dtype) or is_bool_dtype(dtype): + na_value_is_na = isna(na_value) + if na_value_is_na: + if is_integer_dtype(dtype): + na_value = 0 + else: + na_value = True + + result = lib.map_infer_mask( + arr, + f, + mask.view("uint8"), + convert=False, + na_value=na_value, + dtype=np.dtype(cast(type, dtype)), + ) + if na_value_is_na and mask.any(): + if is_integer_dtype(dtype): + result = result.astype("float64") + else: + result = result.astype("object") + result[mask] = np.nan + return result + + elif is_string_dtype(dtype) and not is_object_dtype(dtype): + # i.e. StringDtype + result = lib.map_infer_mask( + arr, f, mask.view("uint8"), convert=False, na_value=na_value + ) + return type(self)(result) + else: + # This is when the result type is object. We reach this when + # -> We know the result type is truly object (e.g. .encode returns bytes + # or .findall returns a list). + # -> We don't know the result type. E.g. `.get` can return anything. + return lib.map_infer_mask(arr, f, mask.view("uint8")) diff --git a/pandas/core/construction.py b/pandas/core/construction.py index 748b9e4947bec..5bccca9cfbd47 100644 --- a/pandas/core/construction.py +++ b/pandas/core/construction.py @@ -569,7 +569,7 @@ def sanitize_array( if isinstance(data, str) and using_string_dtype() and original_dtype is None: from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype("pyarrow", na_value=np.nan) + dtype = StringDtype(na_value=np.nan) data = construct_1d_arraylike_from_scalar(data, len(index), dtype) return data @@ -606,7 +606,7 @@ def sanitize_array( elif data.dtype.kind == "U" and using_string_dtype(): from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype(storage="pyarrow", na_value=np.nan) + dtype = StringDtype(na_value=np.nan) subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype) if subarr is data and copy: diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 5d4e56cd8f800..9580ab1b520e0 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -801,7 +801,7 @@ def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]: if using_string_dtype(): from pandas.core.arrays.string_ import StringDtype - dtype = StringDtype(storage="pyarrow", na_value=np.nan) + dtype = StringDtype(na_value=np.nan) elif isinstance(val, (np.datetime64, dt.datetime)): try: diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py index 144416fc11691..f3dbacc02bec9 100644 --- a/pandas/core/internals/construction.py +++ b/pandas/core/internals/construction.py @@ -376,7 +376,7 @@ def ndarray_to_mgr( nb = new_block_2d(values, placement=bp, refs=refs) block_values = [nb] elif dtype is None and values.dtype.kind == "U" and using_string_dtype(): - dtype = StringDtype(storage="pyarrow", na_value=np.nan) + dtype = StringDtype(na_value=np.nan) obj_columns = list(values) block_values = [ diff --git a/pandas/io/_util.py b/pandas/io/_util.py index 2dc3b0e6c80ef..68fcfcf65e0c2 100644 --- a/pandas/io/_util.py +++ b/pandas/io/_util.py @@ -31,6 +31,6 @@ def arrow_string_types_mapper() -> Callable: pa = import_optional_dependency("pyarrow") return { - pa.string(): pd.StringDtype(storage="pyarrow", na_value=np.nan), - pa.large_string(): pd.StringDtype(storage="pyarrow", na_value=np.nan), + pa.string(): pd.StringDtype(na_value=np.nan), + pa.large_string(): pd.StringDtype(na_value=np.nan), }.get diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 12bb93a63f850..2e38303caa354 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -76,6 +76,7 @@ PeriodIndex, RangeIndex, Series, + StringDtype, TimedeltaIndex, concat, isna, @@ -3225,7 +3226,7 @@ def read( values = self.read_array("values", start=start, stop=stop) result = Series(values, index=index, name=self.name, copy=False) if using_string_dtype() and is_string_array(values, skipna=True): - result = result.astype("string[pyarrow_numpy]") + result = result.astype(StringDtype(na_value=np.nan)) return result def write(self, obj, **kwargs) -> None: @@ -3294,7 +3295,7 @@ def read( columns = items[items.get_indexer(blk_items)] df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) if using_string_dtype() and is_string_array(values, skipna=True): - df = df.astype("string[pyarrow_numpy]") + df = df.astype(StringDtype(na_value=np.nan)) dfs.append(df) if len(dfs) > 0: @@ -4685,7 +4686,7 @@ def read( values, # type: ignore[arg-type] skipna=True, ): - df = df.astype("string[pyarrow_numpy]") + df = df.astype(StringDtype(na_value=np.nan)) frames.append(df) if len(frames) == 1: diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 6bf20b6fcd5f7..b51b01c2b5168 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -15,6 +15,7 @@ import pandas as pd import pandas._testing as tm +from pandas.core.arrays.string_ import StringArrayNumpySemantics from pandas.core.arrays.string_arrow import ( ArrowStringArray, ArrowStringArrayNumpySemantics, @@ -75,6 +76,9 @@ def test_repr(dtype): elif dtype.storage == "pyarrow" and dtype.na_value is np.nan: arr_name = "ArrowStringArrayNumpySemantics" expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string" + elif dtype.storage == "python" and dtype.na_value is np.nan: + arr_name = "StringArrayNumpySemantics" + expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string" else: arr_name = "StringArray" expected = f"<{arr_name}>\n['a', , 'b']\nLength: 3, dtype: string" @@ -90,14 +94,14 @@ def test_none_to_nan(cls, dtype): def test_setitem_validates(cls, dtype): arr = cls._from_sequence(["a", "b"], dtype=dtype) - if cls is pd.arrays.StringArray: + if dtype.storage == "python": msg = "Cannot set non-string value '10' into a StringArray." else: msg = "Scalar must be NA or str" with pytest.raises(TypeError, match=msg): arr[0] = 10 - if cls is pd.arrays.StringArray: + if dtype.storage == "python": msg = "Must provide strings." else: msg = "Scalar must be NA or str" @@ -339,6 +343,8 @@ def test_comparison_methods_array(comparison_op, dtype): def test_constructor_raises(cls): if cls is pd.arrays.StringArray: msg = "StringArray requires a sequence of strings or pandas.NA" + elif cls is StringArrayNumpySemantics: + msg = "StringArrayNumpySemantics requires a sequence of strings or NaN" else: msg = "Unsupported type '' for ArrowExtensionArray" @@ -348,7 +354,7 @@ def test_constructor_raises(cls): with pytest.raises(ValueError, match=msg): cls(np.array([])) - if cls is pd.arrays.StringArray: + if cls is pd.arrays.StringArray or cls is StringArrayNumpySemantics: # GH#45057 np.nan and None do NOT raise, as they are considered valid NAs # for string dtype cls(np.array(["a", np.nan], dtype=object)) @@ -389,6 +395,8 @@ def test_from_sequence_no_mutate(copy, cls, dtype): import pyarrow as pa expected = cls(pa.array(na_arr, type=pa.string(), from_pandas=True)) + elif cls is StringArrayNumpySemantics: + expected = cls(nan_arr) else: expected = cls(na_arr) @@ -671,7 +679,11 @@ def test_isin(dtype, fixed_now_ts): tm.assert_series_equal(result, expected) result = s.isin(["a", pd.NA]) - expected = pd.Series([True, False, True]) + if dtype.storage == "python" and dtype.na_value is np.nan: + # TODO(infer_string) we should make this consistent + expected = pd.Series([True, False, False]) + else: + expected = pd.Series([True, False, True]) tm.assert_series_equal(result, expected) result = s.isin([]) @@ -695,7 +707,7 @@ def test_setitem_scalar_with_mask_validation(dtype): # for other non-string we should also raise an error ser = pd.Series(["a", "b", "c"], dtype=dtype) - if type(ser.array) is pd.arrays.StringArray: + if dtype.storage == "python": msg = "Cannot set non-string value" else: msg = "Scalar must be NA or str" diff --git a/pandas/tests/series/test_constructors.py b/pandas/tests/series/test_constructors.py index 0c39cead78baf..0aaa8ddcfda0c 100644 --- a/pandas/tests/series/test_constructors.py +++ b/pandas/tests/series/test_constructors.py @@ -14,6 +14,7 @@ iNaT, lib, ) +from pandas.compat import HAS_PYARROW from pandas.compat.numpy import np_version_gt2 from pandas.errors import IntCastingNaNError import pandas.util._test_decorators as td @@ -2094,11 +2095,10 @@ def test_series_from_index_dtype_equal_does_not_copy(self): def test_series_string_inference(self): # GH#54430 - pytest.importorskip("pyarrow") - dtype = "string[pyarrow_numpy]" - expected = Series(["a", "b"], dtype=dtype) with pd.option_context("future.infer_string", True): ser = Series(["a", "b"]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series(["a", "b"], dtype=dtype) tm.assert_series_equal(ser, expected) expected = Series(["a", 1], dtype="object") @@ -2109,35 +2109,33 @@ def test_series_string_inference(self): @pytest.mark.parametrize("na_value", [None, np.nan, pd.NA]) def test_series_string_with_na_inference(self, na_value): # GH#54430 - pytest.importorskip("pyarrow") - dtype = "string[pyarrow_numpy]" - expected = Series(["a", na_value], dtype=dtype) with pd.option_context("future.infer_string", True): ser = Series(["a", na_value]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series(["a", None], dtype=dtype) tm.assert_series_equal(ser, expected) def test_series_string_inference_scalar(self): # GH#54430 - pytest.importorskip("pyarrow") - expected = Series("a", index=[1], dtype="string[pyarrow_numpy]") with pd.option_context("future.infer_string", True): ser = Series("a", index=[1]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series("a", index=[1], dtype=dtype) tm.assert_series_equal(ser, expected) def test_series_string_inference_array_string_dtype(self): # GH#54496 - pytest.importorskip("pyarrow") - expected = Series(["a", "b"], dtype="string[pyarrow_numpy]") with pd.option_context("future.infer_string", True): ser = Series(np.array(["a", "b"])) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series(["a", "b"], dtype=dtype) tm.assert_series_equal(ser, expected) def test_series_string_inference_storage_definition(self): # https://github.com/pandas-dev/pandas/issues/54793 # but after PDEP-14 (string dtype), it was decided to keep dtype="string" # returning the NA string dtype, so expected is changed from - # "string[pyarrow_numpy]" to "string[pyarrow]" - pytest.importorskip("pyarrow") + # "string[pyarrow_numpy]" to "string[python]" expected = Series(["a", "b"], dtype="string[python]") with pd.option_context("future.infer_string", True): result = Series(["a", "b"], dtype="string") @@ -2153,10 +2151,10 @@ def test_series_constructor_infer_string_scalar(self): def test_series_string_inference_na_first(self): # GH#55655 - pytest.importorskip("pyarrow") - expected = Series([pd.NA, "b"], dtype="string[pyarrow_numpy]") with pd.option_context("future.infer_string", True): result = Series([pd.NA, "b"]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series([None, "b"], dtype=dtype) tm.assert_series_equal(result, expected) def test_inference_on_pandas_objects(self): diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index df490297e2a5c..2d7c9754ee319 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -236,7 +236,7 @@ def test_contains_nan(any_string_dtype): result = s.str.contains("foo", na="foo") if any_string_dtype == "object": expected = Series(["foo", "foo", "foo"], dtype=np.object_) - elif any_string_dtype == "string[pyarrow_numpy]": + elif any_string_dtype.na_value is np.nan: expected = Series([True, True, True], dtype=np.bool_) else: expected = Series([True, True, True], dtype="boolean")