diff --git a/doc/source/user_guide/text.rst b/doc/source/user_guide/text.rst index d521c745ccfe5..072871f89bdae 100644 --- a/doc/source/user_guide/text.rst +++ b/doc/source/user_guide/text.rst @@ -13,7 +13,7 @@ Text Data Types .. versionadded:: 1.0.0 -There are two main ways to store text data +There are two ways to store text data in pandas: 1. ``object`` -dtype NumPy array. 2. :class:`StringDtype` extension type. @@ -63,7 +63,40 @@ Or ``astype`` after the ``Series`` or ``DataFrame`` is created s s.astype("string") -Everything that follows in the rest of this document applies equally to +.. _text.differences: + +Behavior differences +^^^^^^^^^^^^^^^^^^^^ + +These are places where the behavior of ``StringDtype`` objects differ from +``object`` dtype + +l. For ``StringDtype``, :ref:`string accessor methods` + that return **numeric** output will always return a nullable integer dtype, + rather than either int or float dtype, depending on the presence of NA values. + + .. ipython:: python + + s = pd.Series(["a", None, "b"], dtype="string") + s + s.str.count("a") + s.dropna().str.count("a") + + Both outputs are ``Int64`` dtype. Compare that with object-dtype + + .. ipython:: python + + s.astype(object).str.count("a") + s.astype(object).dropna().str.count("a") + + When NA values are present, the output dtype is float64. + +2. Some string methods, like :meth:`Series.str.decode` are not available + on ``StringArray`` because ``StringArray`` only holds strings, not + bytes. + + +Everything else that follows in the rest of this document applies equally to ``string`` and ``object`` dtype. .. _text.string_methods: diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index 54e54751a1f89..4ce21139d5131 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -63,7 +63,7 @@ Previously, strings were typically stored in object-dtype NumPy arrays. ``StringDtype`` is currently considered experimental. The implementation and parts of the API may change without warning. -The text extension type solves several issues with object-dtype NumPy arrays: +The ``'string'`` extension type solves several issues with object-dtype NumPy arrays: 1. You can accidentally store a *mixture* of strings and non-strings in an ``object`` dtype array. A ``StringArray`` can only store strings. @@ -88,9 +88,17 @@ You can use the alias ``"string"`` as well. The usual string accessor methods work. Where appropriate, the return type of the Series or columns of a DataFrame will also have string dtype. +.. ipython:: python + s.str.upper() s.str.split('b', expand=True).dtypes +String accessor methods returning integers will return a value with :class:`Int64Dtype` + +.. ipython:: python + + s.str.count("a") + We recommend explicitly using the ``string`` data type when working with strings. See :ref:`text.types` for more. diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index c1fd46f4bba9e..aaf6456df8f8e 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -2208,9 +2208,13 @@ def maybe_convert_objects(ndarray[object] objects, bint try_float=0, return objects +_no_default = object() + + @cython.boundscheck(False) @cython.wraparound(False) -def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=1): +def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=1, + object na_value=_no_default, object dtype=object): """ Substitute for np.vectorize with pandas-friendly dtype inference @@ -2218,6 +2222,15 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=1) ---------- arr : ndarray f : function + mask : ndarray + uint8 dtype ndarray indicating values not to apply `f` to. + convert : bool, default True + Whether to call `maybe_convert_objects` on the resulting ndarray + na_value : Any, optional + The result value to use for masked values. By default, the + input value is used + dtype : numpy.dtype + The numpy dtype to use for the result ndarray. Returns ------- @@ -2225,14 +2238,17 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=1) """ cdef: Py_ssize_t i, n - ndarray[object] result + ndarray result object val n = len(arr) - result = np.empty(n, dtype=object) + result = np.empty(n, dtype=dtype) for i in range(n): if mask[i]: - val = arr[i] + if na_value is _no_default: + val = arr[i] + else: + val = na_value else: val = f(arr[i]) diff --git a/pandas/core/strings.py b/pandas/core/strings.py index 55ce44d736864..413e7e85eb6fe 100644 --- a/pandas/core/strings.py +++ b/pandas/core/strings.py @@ -2,7 +2,7 @@ from functools import wraps import re import textwrap -from typing import Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List import warnings import numpy as np @@ -15,10 +15,14 @@ ensure_object, is_bool_dtype, is_categorical_dtype, + is_extension_array_dtype, is_integer, + is_integer_dtype, is_list_like, + is_object_dtype, is_re, is_scalar, + is_string_dtype, ) from pandas.core.dtypes.generic import ( ABCDataFrame, @@ -28,9 +32,14 @@ ) from pandas.core.dtypes.missing import isna +from pandas._typing import ArrayLike, Dtype from pandas.core.algorithms import take_1d from pandas.core.base import NoNewAttributesMixin import pandas.core.common as com +from pandas.core.construction import extract_array + +if TYPE_CHECKING: + from pandas.arrays import StringArray _cpython_optimized_encoders = ( "utf-8", @@ -109,10 +118,79 @@ def cat_safe(list_of_columns: List, sep: str): def _na_map(f, arr, na_result=np.nan, dtype=object): # should really _check_ for NA - return _map(f, arr, na_mask=True, na_value=na_result, dtype=dtype) + if is_extension_array_dtype(arr.dtype): + # just StringDtype + arr = extract_array(arr) + return _map_stringarray(f, arr, na_value=na_result, dtype=dtype) + return _map_object(f, arr, na_mask=True, na_value=na_result, dtype=dtype) + + +def _map_stringarray( + func: Callable[[str], Any], arr: "StringArray", na_value: Any, dtype: Dtype +) -> ArrayLike: + """ + Map a callable over valid elements of a StringArrray. + + Parameters + ---------- + func : Callable[[str], Any] + Apply to each valid element. + arr : StringArray + na_value : Any + The value to use for missing values. By default, this is + the original value (NA). + dtype : Dtype + The result dtype to use. Specifying this aviods an intermediate + object-dtype allocation. + + Returns + ------- + ArrayLike + An ExtensionArray for integer or string dtypes, otherwise + an ndarray. + + """ + from pandas.arrays import IntegerArray, StringArray + + mask = isna(arr) + + assert isinstance(arr, StringArray) + arr = np.asarray(arr) + + if is_integer_dtype(dtype): + na_value_is_na = isna(na_value) + if na_value_is_na: + na_value = 1 + result = lib.map_infer_mask( + arr, + func, + mask.view("uint8"), + convert=False, + na_value=na_value, + dtype=np.dtype("int64"), + ) + + if not na_value_is_na: + mask[:] = False + + return IntegerArray(result, mask) + + elif is_string_dtype(dtype) and not is_object_dtype(dtype): + # i.e. StringDtype + result = lib.map_infer_mask( + arr, func, mask.view("uint8"), convert=False, na_value=na_value + ) + return StringArray(result) + # TODO: BooleanArray + else: + # This is when the result type is object. We reach this when + # -> We know the result type is truly object (e.g. .encode returns bytes + # or .findall returns a list). + # -> We don't know the result type. E.g. `.get` can return anything. + return lib.map_infer_mask(arr, func, mask.view("uint8")) -def _map(f, arr, na_mask=False, na_value=np.nan, dtype=object): +def _map_object(f, arr, na_mask=False, na_value=np.nan, dtype=object): if not len(arr): return np.ndarray(0, dtype=dtype) @@ -143,7 +221,7 @@ def g(x): except (TypeError, AttributeError): return na_value - return _map(g, arr, dtype=dtype) + return _map_object(g, arr, dtype=dtype) if na_value is not np.nan: np.putmask(result, mask, na_value) if result.dtype == object: @@ -634,7 +712,7 @@ def str_replace(arr, pat, repl, n=-1, case=None, flags=0, regex=True): raise ValueError("Cannot use a callable replacement when regex=False") f = lambda x: x.replace(pat, repl, n) - return _na_map(f, arr) + return _na_map(f, arr, dtype=str) def str_repeat(arr, repeats): @@ -685,7 +763,7 @@ def scalar_rep(x): except TypeError: return str.__mul__(x, repeats) - return _na_map(scalar_rep, arr) + return _na_map(scalar_rep, arr, dtype=str) else: def rep(x, r): @@ -1150,7 +1228,7 @@ def str_join(arr, sep): 4 NaN dtype: object """ - return _na_map(sep.join, arr) + return _na_map(sep.join, arr, dtype=str) def str_findall(arr, pat, flags=0): @@ -1381,7 +1459,7 @@ def str_pad(arr, width, side="left", fillchar=" "): else: # pragma: no cover raise ValueError("Invalid side") - return _na_map(f, arr) + return _na_map(f, arr, dtype=str) def str_split(arr, pat=None, n=None): @@ -1487,7 +1565,7 @@ def str_slice(arr, start=None, stop=None, step=None): """ obj = slice(start, stop, step) f = lambda x: x[obj] - return _na_map(f, arr) + return _na_map(f, arr, dtype=str) def str_slice_replace(arr, start=None, stop=None, repl=None): @@ -1578,7 +1656,7 @@ def f(x): y += x[local_stop:] return y - return _na_map(f, arr) + return _na_map(f, arr, dtype=str) def str_strip(arr, to_strip=None, side="both"): @@ -1603,7 +1681,7 @@ def str_strip(arr, to_strip=None, side="both"): f = lambda x: x.rstrip(to_strip) else: # pragma: no cover raise ValueError("Invalid side") - return _na_map(f, arr) + return _na_map(f, arr, dtype=str) def str_wrap(arr, width, **kwargs): @@ -1667,7 +1745,7 @@ def str_wrap(arr, width, **kwargs): tw = textwrap.TextWrapper(**kwargs) - return _na_map(lambda s: "\n".join(tw.wrap(s)), arr) + return _na_map(lambda s: "\n".join(tw.wrap(s)), arr, dtype=str) def str_translate(arr, table): @@ -1687,7 +1765,7 @@ def str_translate(arr, table): ------- Series or Index """ - return _na_map(lambda x: x.translate(table), arr) + return _na_map(lambda x: x.translate(table), arr, dtype=str) def str_get(arr, i): @@ -3025,7 +3103,7 @@ def normalize(self, form): import unicodedata f = lambda x: unicodedata.normalize(form, x) - result = _na_map(f, self._parent) + result = _na_map(f, self._parent, dtype=str) return self._wrap_result(result) _shared_docs[ @@ -3223,31 +3301,37 @@ def rindex(self, sub, start=0, end=None): lambda x: x.lower(), name="lower", docstring=_shared_docs["casemethods"] % _doc_args["lower"], + dtype=str, ) upper = _noarg_wrapper( lambda x: x.upper(), name="upper", docstring=_shared_docs["casemethods"] % _doc_args["upper"], + dtype=str, ) title = _noarg_wrapper( lambda x: x.title(), name="title", docstring=_shared_docs["casemethods"] % _doc_args["title"], + dtype=str, ) capitalize = _noarg_wrapper( lambda x: x.capitalize(), name="capitalize", docstring=_shared_docs["casemethods"] % _doc_args["capitalize"], + dtype=str, ) swapcase = _noarg_wrapper( lambda x: x.swapcase(), name="swapcase", docstring=_shared_docs["casemethods"] % _doc_args["swapcase"], + dtype=str, ) casefold = _noarg_wrapper( lambda x: x.casefold(), name="casefold", docstring=_shared_docs["casemethods"] % _doc_args["casefold"], + dtype=str, ) _shared_docs[ diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index f68541b620efa..1261c3bbc86db 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -731,7 +731,10 @@ def test_count(self): tm.assert_series_equal(result, exp) # mixed - mixed = ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0] + mixed = np.array( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=object, + ) rs = strings.str_count(mixed, "a") xp = np.array([1, np.nan, 0, np.nan, np.nan, 0, np.nan, np.nan, np.nan]) tm.assert_numpy_array_equal(rs, xp) @@ -755,14 +758,14 @@ def test_contains(self): expected = np.array([False, np.nan, False, False, True], dtype=np.object_) tm.assert_numpy_array_equal(result, expected) - values = ["foo", "xyz", "fooommm__foo", "mmm_"] + values = np.array(["foo", "xyz", "fooommm__foo", "mmm_"], dtype=object) result = strings.str_contains(values, pat) expected = np.array([False, False, True, True]) assert result.dtype == np.bool_ tm.assert_numpy_array_equal(result, expected) # case insensitive using regex - values = ["Foo", "xYz", "fOOomMm__fOo", "MMM_"] + values = np.array(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dtype=object) result = strings.str_contains(values, "FOO|mmm", case=False) expected = np.array([True, False, True, True]) tm.assert_numpy_array_equal(result, expected) @@ -773,7 +776,10 @@ def test_contains(self): tm.assert_numpy_array_equal(result, expected) # mixed - mixed = ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0] + mixed = np.array( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=object, + ) rs = strings.str_contains(mixed, "o") xp = np.array( [False, np.nan, False, np.nan, np.nan, True, np.nan, np.nan, np.nan], @@ -869,7 +875,10 @@ def test_endswith(self): tm.assert_series_equal(result, exp.fillna(False).astype(bool)) # mixed - mixed = ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0] + mixed = np.array( + ["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0], + dtype=object, + ) rs = strings.str_endswith(mixed, "f") xp = np.array( [False, np.nan, False, np.nan, np.nan, False, np.nan, np.nan, np.nan], @@ -3489,10 +3498,13 @@ def test_casefold(self): def test_string_array(any_string_method): + method_name, args, kwargs = any_string_method + if method_name == "decode": + pytest.skip("decode requires bytes.") + data = ["a", "bb", np.nan, "ccc"] a = Series(data, dtype=object) b = Series(data, dtype="string") - method_name, args, kwargs = any_string_method expected = getattr(a.str, method_name)(*args, **kwargs) result = getattr(b.str, method_name)(*args, **kwargs) @@ -3503,8 +3515,29 @@ def test_string_array(any_string_method): ): assert result.dtype == "string" result = result.astype(object) + + elif expected.dtype == "float" and expected.isna().any(): + assert result.dtype == "Int64" + result = result.astype("float") + elif isinstance(expected, DataFrame): columns = expected.select_dtypes(include="object").columns assert all(result[columns].dtypes == "string") result[columns] = result[columns].astype(object) tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ("count", [2, None]), + ("find", [0, None]), + ("index", [0, None]), + ("rindex", [2, None]), + ], +) +def test_string_array_numeric_integer_array(method, expected): + s = Series(["aba", None], dtype="string") + result = getattr(s.str, method)("a") + expected = Series(expected, dtype="Int64") + tm.assert_series_equal(result, expected)