From 0e8011500dae04abe652212b771d7fcc74ff75ba Mon Sep 17 00:00:00 2001 From: Todd Jennings Date: Mon, 7 Dec 2020 23:52:56 -0500 Subject: [PATCH] support elementwise operations in many str accessor functions --- xarray/core/accessor_str.py | 1067 ++++++++++------- xarray/tests/test_accessor_str.py | 1786 ++++++++++++++++++++--------- 2 files changed, 1952 insertions(+), 901 deletions(-) diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 39cbd6d53fc..81ba87b12ca 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -42,7 +42,7 @@ import textwrap from functools import reduce from operator import or_ as set_union -from typing import Any, Callable, Hashable, Mapping, Optional, Pattern, Union +from typing import Any, Callable, Hashable, Mapping, Optional, Pattern, Tuple, Union from unicodedata import normalize import numpy as np @@ -61,8 +61,72 @@ _cpython_optimized_decoders = _cpython_optimized_encoders + ("utf-16", "utf-32") -def _is_str_like(x: Any) -> bool: - return isinstance(x, str) or isinstance(x, bytes) +def _contains_obj_type(*, pat: Any, checker: Any) -> bool: + """Determine if the object fits some rule or is array of objects that do so.""" + if isinstance(checker, type): + targtype = checker + checker = lambda x: isinstance(x, targtype) + + if checker(pat): + return True + + # If it is not an object array it can't contain compiled re + if not getattr(pat, "dtype", "no") == np.object_: + return False + + return _apply_str_ufunc(func=checker, obj=pat).all() + + +def _contains_str_like(pat: Any) -> bool: + """Determine if the object is a str-like or array of str-like.""" + if isinstance(pat, (str, bytes)): + return True + + if not hasattr(pat, "dtype"): + return False + + return pat.dtype.kind == "U" or pat.dtype.kind == "S" + + +def _contains_compiled_re(pat: Any) -> bool: + """Determine if the object is a compiled re or array of compiled re.""" + return _contains_obj_type(pat=pat, checker=re.Pattern) + + +def _contains_callable(pat: Any) -> bool: + """Determine if the object is a callable or array of callables.""" + return _contains_obj_type(pat=pat, checker=callable) + + +def _apply_str_ufunc( + *, + func: Callable, + obj: Any, + dtype: Union[str, np.dtype] = None, + output_core_dims: Union[list, tuple] = ((),), + output_sizes: Mapping[Hashable, int] = None, + func_args: Tuple = (), + func_kwargs: Mapping = {}, +) -> Any: + # TODO handling of na values ? + if dtype is None: + dtype = obj.dtype + + dask_gufunc_kwargs = dict() + if output_sizes is not None: + dask_gufunc_kwargs["output_sizes"] = output_sizes + + return apply_ufunc( + func, + obj, + *func_args, + vectorize=True, + dask="parallelized", + output_dtypes=[dtype], + output_core_dims=output_core_dims, + dask_gufunc_kwargs=dask_gufunc_kwargs, + **func_kwargs, + ) class StringAccessor: @@ -77,12 +141,58 @@ class StringAccessor: array([4, 4, 2, 2, 5]) Dimensions without coordinates: dim_0 - It also implements `+`, `*`, and `%`, which operate as elementwise - versions of the corresponding `str` methods. + It also implements ``+``, ``*``, and ``%``, which operate as elementwise + versions of the corresponding ``str`` methods. These will automatically + broadcast for array-like inputs. + + >>> da1 = xr.DataArray(["first", "second", "third"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) + >>> da1.str + da2 + + array([['first1', 'first2', 'first3'], + ['second1', 'second2', 'second3'], + ['third1', 'third2', 'third3']], dtype='>> da1 = xr.DataArray(["a", "b", "c", "d"], dims=["X"]) + >>> reps = xr.DataArray([3, 4], dims=["Y"]) + >>> da1.str * reps + + array([['aaa', 'aaaa'], + ['bbb', 'bbbb'], + ['ccc', 'cccc'], + ['ddd', 'dddd']], dtype='>> da1 = xr.DataArray(["%s_%s", "%s-%s", "%s|%s"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2], dims=["Y"]) + >>> da3 = xr.DataArray([0.1, 0.2], dims=["Z"]) + >>> da1.str % (da2, da3) + + array([[['1_0.1', '1_0.2'], + ['2_0.1', '2_0.2']], + + [['1-0.1', '1-0.2'], + ['2-0.1', '2-0.2']], + + [['1|0.1', '1|0.2'], + ['2|0.1', '2|0.2']]], dtype='>> da1 = xr.DataArray(["%(a)s"], dims=["X"]) + >>> da2 = xr.DataArray([1, 2, 3], dims=["Y"]) + >>> da1 % {"a": da2} + + array(['\\narray([1, 2, 3])\\nDimensions without coordinates: Y'], + dtype=object) + Dimensions without coordinates: X """ __slots__ = ("_obj",) - _pattern_type = type(re.compile("")) def __init__(self, obj): self._obj = obj @@ -103,46 +213,38 @@ def _stringify( def _apply( self, + *, func: Callable, - *args, - obj: Any = None, dtype: Union[str, np.dtype] = None, output_core_dims: Union[list, tuple] = ((),), output_sizes: Mapping[Hashable, int] = None, - **kwargs, + func_args: Tuple = (), + func_kwargs: Mapping = {}, ) -> Any: - # TODO handling of na values ? - if obj is None: - obj = self._obj - if dtype is None: - dtype = obj.dtype - - dask_gufunc_kwargs = dict() - if output_sizes is not None: - dask_gufunc_kwargs["output_sizes"] = output_sizes - - return apply_ufunc( - func, - obj, - *args, - vectorize=True, - dask="parallelized", - output_dtypes=[dtype], + return _apply_str_ufunc( + obj=self._obj, + func=func, + dtype=dtype, output_core_dims=output_core_dims, - dask_gufunc_kwargs=dask_gufunc_kwargs, - **kwargs, + output_sizes=output_sizes, + func_args=func_args, + func_kwargs=func_kwargs, ) def _re_compile( - self, pat: Union[str, bytes, Pattern], flags: int, case: bool = None - ) -> Pattern: - is_compiled_re = isinstance(pat, self._pattern_type) + self, + *, + pat: Union[str, bytes, Pattern, Any], + flags: int = 0, + case: bool = None, + ) -> Union[Pattern, Any]: + is_compiled_re = isinstance(pat, re.Pattern) if is_compiled_re and flags != 0: - raise ValueError("flags cannot be set when pat is a compiled regex") + raise ValueError("Flags cannot be set when pat is a compiled regex.") if is_compiled_re and case is not None: - raise ValueError("case cannot be set when pat is a compiled regex") + raise ValueError("Case cannot be set when pat is a compiled regex.") if is_compiled_re: # no-op, needed to tell mypy this isn't a string @@ -156,8 +258,15 @@ def _re_compile( if not case: flags |= re.IGNORECASE - pat = self._stringify(pat) - return re.compile(pat, flags=flags) + if getattr(pat, "dtype", None) != np.object_: + pat = self._stringify(pat) + func = lambda x: re.compile(x, flags=flags) + if isinstance(pat, np.ndarray): + # apply_ufunc doesn't work for numpy arrays with output object dtypes + func = np.vectorize(func) + return func(pat) + else: + return _apply_str_ufunc(func=func, obj=pat, dtype=np.object_) def len(self) -> Any: """ @@ -167,7 +276,7 @@ def len(self) -> Any: ------- lengths array : array of int """ - return self._apply(len, dtype=int) + return self._apply(func=len, dtype=int) def __getitem__( self, @@ -186,33 +295,39 @@ def __add__( def __mul__( self, - num: int, + num: Union[int, Any], ) -> Any: - if num <= 0: - return self[:0] - if num == 1: - return self._obj.copy() - else: - return self.repeat(num) + return self.repeat(num) def __mod__( self, other: Any, ) -> Any: - return self._apply(lambda x: x % other) + if isinstance(other, dict): + other = {key: self._stringify(val) for key, val in other.items()} + return self._apply(func=lambda x: x % other) + elif isinstance(other, tuple): + other = tuple(self._stringify(x) for x in other) + return self._apply(func=lambda x, *y: x % y, func_args=other) + else: + return self._apply(func=lambda x, y: x % y, func_args=(other,)) def get( self, - i: int, + i: Union[int, Any], default: Union[str, bytes] = "", ) -> Any: """ Extract character number `i` from each string in the array. + If `i` is array-like, they are broadcast against the array and + applied elementwise. + Parameters ---------- - i : int + i : int or array-like of int Position of element to extract. + If array-like, it is broadcast. default : optional Value for out-of-range index. If not specified (None) defaults to an empty string. @@ -221,63 +336,71 @@ def get( ------- items : array of object """ - s = slice(-1, None) if i == -1 else slice(i, i + 1) - def f(x): - item = x[s] + def f(x, iind): + islice = slice(-1, None) if iind == -1 else slice(iind, iind + 1) + item = x[islice] return item if item else default - return self._apply(f) + return self._apply(func=f, func_args=(i,)) def slice( self, - start: int = None, - stop: int = None, - step: int = None, + start: Union[int, Any] = None, + stop: Union[int, Any] = None, + step: Union[int, Any] = None, ) -> Any: """ Slice substrings from each string in the array. + If `start`, `stop`, or 'step` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Start position for slice operation. - stop : int, optional + If array-like, it is broadcast. + stop : int or array-like of int, optional Stop position for slice operation. - step : int, optional + If array-like, it is broadcast. + step : int or array-like of int, optional Step size for slice operation. + If array-like, it is broadcast. Returns ------- sliced strings : same type as values """ - s = slice(start, stop, step) - f = lambda x: x[s] - return self._apply(f) + f = lambda x, istart, istop, istep: x[slice(istart, istop, istep)] + return self._apply(func=f, func_args=(start, stop, step)) def slice_replace( self, - start: int = None, - stop: int = None, - repl: Union[str, bytes] = "", + start: Union[int, Any] = None, + stop: Union[int, Any] = None, + repl: Union[str, bytes, Any] = "", ) -> Any: """ Replace a positional slice of a string with another value. + If `start`, `stop`, or 'repl` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - start : int, optional + start : int or array-like of int, optional Left index position to use for the slice. If not specified (None), the slice is unbounded on the left, i.e. slice from the start - of the string. - stop : int, optional + of the string. If array-like, it is broadcast. + stop : int or array-like of int, optional Right index position to use for the slice. If not specified (None), the slice is unbounded on the right, i.e. slice until the - end of the string. - repl : str, optional + end of the string. If array-like, it is broadcast. + repl : str or array-like of str, optional String for replacement. If not specified, the sliced region - is replaced with an empty string. + is replaced with an empty string. If array-like, it is broadcast. Returns ------- @@ -285,25 +408,25 @@ def slice_replace( """ repl = self._stringify(repl) - def f(x): - if len(x[start:stop]) == 0: - local_stop = start + def func(x, istart, istop, irepl): + if len(x[istart:istop]) == 0: + local_stop = istart else: - local_stop = stop + local_stop = istop y = self._stringify("") - if start is not None: - y += x[:start] - y += repl - if stop is not None: + if istart is not None: + y += x[:istart] + y += irepl + if istop is not None: y += x[local_stop:] return y - return self._apply(f) + return self._apply(func=func, func_args=(start, stop, repl)) def cat( self, *others, - sep: Any = "", + sep: Union[str, bytes, Any] = "", ) -> Any: """ Concatenate strings elementwise in the DataArray with other strings. @@ -311,14 +434,15 @@ def cat( The other strings can either be string scalars or other array-like. Dimensions are automatically broadcast together. - An optional separator can also be specified. + An optional separator `sep` can also be specified. If `sep` is + array-like, it is broadcast against the array and applied elementwise. Parameters ---------- *others : str or array-like of str Strings or array-like of strings to concatenate elementwise with the current DataArray. - sep : str or array-like, default `""`. + sep : str or array-like of str, default: "". Seperator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -366,7 +490,6 @@ def cat( ['4 cccc 3.4 test', '4, cccc, 3.4, , test']]], dtype=' Any: """ Concatenate strings in a DataArray along a particular dimension. - An optional separator can also be specified. + An optional separator `sep` can also be specified. If `sep` is + array-like, it is broadcast against the array and applied elementwise. Parameters ---------- - dim : Hashable, optional + dim : hashable, optional Dimension along which the strings should be concatenated. + Only one dimension is allowed at a time. Optional for 0D or 1D DataArrays, required for multidimensional DataArrays. - sep : str or array-like, default `""`. + sep : str or array-like, default: "". Seperator to use between strings. It is broadcast in the same way as the other input strings. If array-like, its dimensions will be placed at the end of the output array dimensions. @@ -433,7 +558,6 @@ def join( ['abcd--abcdef', 'abcd__abcdef']], dtype=' - [['spam and lancelot are like a duck', - 'spam and arthur are like a duck'], + [['spam and lancelot are like a duck', + 'spam and arthur are like a duck'], ['egg and lancelot are like a duck', - 'egg and arthur are like a duck']]], dtype=' Any: """ @@ -546,7 +671,7 @@ def capitalize(self) -> Any: ------- capitalized : same type as values """ - return self._apply(lambda x: x.capitalize()) + return self._apply(func=lambda x: x.capitalize()) def lower(self) -> Any: """ @@ -556,7 +681,7 @@ def lower(self) -> Any: ------- lowerd : same type as values """ - return self._apply(lambda x: x.lower()) + return self._apply(func=lambda x: x.lower()) def swapcase(self) -> Any: """ @@ -566,7 +691,7 @@ def swapcase(self) -> Any: ------- swapcased : same type as values """ - return self._apply(lambda x: x.swapcase()) + return self._apply(func=lambda x: x.swapcase()) def title(self) -> Any: """ @@ -576,7 +701,7 @@ def title(self) -> Any: ------- titled : same type as values """ - return self._apply(lambda x: x.title()) + return self._apply(func=lambda x: x.title()) def upper(self) -> Any: """ @@ -586,7 +711,7 @@ def upper(self) -> Any: ------- uppered : same type as values """ - return self._apply(lambda x: x.upper()) + return self._apply(func=lambda x: x.upper()) def casefold(self) -> Any: """ @@ -601,7 +726,7 @@ def casefold(self) -> Any: ------- casefolded : same type as values """ - return self._apply(lambda x: x.casefold()) + return self._apply(func=lambda x: x.casefold()) def normalize( self, @@ -615,16 +740,15 @@ def normalize( Parameters ---------- - form : {"NFC", "NFKC", "NFD", and "NFKD"} + form : {"NFC", "NFKC", "NFD", "NFKD"} Unicode form. Returns ------- normalized : same type as values - """ - return self._apply(lambda x: normalize(form, x)) + return self._apply(func=lambda x: normalize(form, x)) def isalnum(self) -> Any: """ @@ -635,7 +759,7 @@ def isalnum(self) -> Any: isalnum : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isalnum(), dtype=bool) + return self._apply(func=lambda x: x.isalnum(), dtype=bool) def isalpha(self) -> Any: """ @@ -646,7 +770,7 @@ def isalpha(self) -> Any: isalpha : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isalpha(), dtype=bool) + return self._apply(func=lambda x: x.isalpha(), dtype=bool) def isdecimal(self) -> Any: """ @@ -657,7 +781,7 @@ def isdecimal(self) -> Any: isdecimal : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isdecimal(), dtype=bool) + return self._apply(func=lambda x: x.isdecimal(), dtype=bool) def isdigit(self) -> Any: """ @@ -668,7 +792,7 @@ def isdigit(self) -> Any: isdigit : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isdigit(), dtype=bool) + return self._apply(func=lambda x: x.isdigit(), dtype=bool) def islower(self) -> Any: """ @@ -679,7 +803,7 @@ def islower(self) -> Any: islower : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.islower(), dtype=bool) + return self._apply(func=lambda x: x.islower(), dtype=bool) def isnumeric(self) -> Any: """ @@ -690,7 +814,7 @@ def isnumeric(self) -> Any: isnumeric : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isnumeric(), dtype=bool) + return self._apply(func=lambda x: x.isnumeric(), dtype=bool) def isspace(self) -> Any: """ @@ -701,7 +825,7 @@ def isspace(self) -> Any: isspace : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isspace(), dtype=bool) + return self._apply(func=lambda x: x.isspace(), dtype=bool) def istitle(self) -> Any: """ @@ -712,7 +836,7 @@ def istitle(self) -> Any: istitle : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.istitle(), dtype=bool) + return self._apply(func=lambda x: x.istitle(), dtype=bool) def isupper(self) -> Any: """ @@ -723,13 +847,13 @@ def isupper(self) -> Any: isupper : array of bool Array of boolean values with the same shape as the original array. """ - return self._apply(lambda x: x.isupper(), dtype=bool) + return self._apply(func=lambda x: x.isupper(), dtype=bool) def count( self, - pat: Union[str, bytes, Pattern], + pat: Union[str, bytes, Pattern, Any], flags: int = 0, - case: bool = True, + case: bool = None, ) -> Any: """ Count occurrences of pattern in each string of the array. @@ -738,15 +862,19 @@ def count( pattern is repeated in each of the string elements of the :class:`~xarray.DataArray`. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or a compiled regular - expression object. + expression object. If array-like, it is broadcast. flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. case : bool, default: True If True, case sensitive. @@ -757,22 +885,26 @@ def count( ------- counts : array of int """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) - f = lambda x: len(pat.findall(x)) - return self._apply(f, dtype=int) + func = lambda x, ipat: len(ipat.findall(x)) + return self._apply(func=func, func_args=(pat,), dtype=int) def startswith( self, - pat: Union[str, bytes], + pat: Union[str, bytes, Any], ) -> Any: """ Test if the start of each string in the array matches a pattern. + The pattern `pat` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -781,20 +913,24 @@ def startswith( the start of each string element. """ pat = self._stringify(pat) - f = lambda x: x.startswith(pat) - return self._apply(f, dtype=bool) + func = lambda x, y: x.startswith(y) + return self._apply(func=func, func_args=(pat,), dtype=bool) def endswith( self, - pat: Union[str, bytes], + pat: Union[str, bytes, Any], ) -> Any: """ Test if the end of each string in the array matches a pattern. + The pattern `pat` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- pat : str Character sequence. Regular expressions are not accepted. + If array-like, it is broadcast. Returns ------- @@ -803,119 +939,150 @@ def endswith( the end of each string element. """ pat = self._stringify(pat) - f = lambda x: x.endswith(pat) - return self._apply(f, dtype=bool) + func = lambda x, y: x.endswith(y) + return self._apply(func=func, func_args=(pat,), dtype=bool) def pad( self, - width: int, + width: Union[int, Any], side: str = "left", - fillchar: Union[str, bytes] = " ", + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad strings in the array up to width. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with character defined in `fillchar`. + filled with character defined in ``fillchar``. + If array-like, it is broadcast. side : {"left", "right", "both"}, default: "left" Side from which to fill resulting string. - fillchar : str, default: " " - Additional character for filling, default is whitespace. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values Array with a minimum number of char in each element. """ - width = int(width) - fillchar = self._stringify(fillchar) - if len(fillchar) != 1: - raise TypeError("fillchar must be a character, not str") - if side == "left": - f = lambda s: s.rjust(width, fillchar) + func = self.rjust elif side == "right": - f = lambda s: s.ljust(width, fillchar) + func = self.ljust elif side == "both": - f = lambda s: s.center(width, fillchar) + func = self.center else: # pragma: no cover raise ValueError("Invalid side") - return self._apply(f) + return func(width=width, fillchar=fillchar) + + def _padder( + self, + *, + func: Callable, + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", + ) -> Any: + """ + Wrapper function to handle padding operations + """ + fillchar = self._stringify(fillchar) + + def overfunc(x, iwidth, ifillchar): + if len(ifillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return func(x, int(iwidth), ifillchar) + + return self._apply(func=overfunc, func_args=(width, fillchar)) def center( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad left and right side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="both", fillchar=fillchar) + func = self._obj.dtype.type.center + return self._padder(func=func, width=width, fillchar=fillchar) def ljust( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad right side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="right", fillchar=fillchar) + func = self._obj.dtype.type.ljust + return self._padder(func=func, width=width, fillchar=fillchar) def rjust( self, - width: int, - fillchar: Union[str, bytes] = " ", + width: Union[int, Any], + fillchar: Union[str, bytes, Any] = " ", ) -> Any: """ Pad left side of each string in the array. + If `width` or 'fillchar` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum width of resulting string; additional characters will be - filled with ``fillchar`` - fillchar : str, default: " " - Additional character for filling, default is whitespace + filled with ``fillchar``. If array-like, it is broadcast. + fillchar : str or array-like of str, default: " " + Additional character for filling, default is a space. + If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="left", fillchar=fillchar) + func = self._obj.dtype.type.rjust + return self._padder(func=func, width=width, fillchar=fillchar) - def zfill( - self, - width: int, - ) -> Any: + def zfill(self, width: Union[int, Any]) -> Any: """ Pad each string in the array by prepending '0' characters. @@ -923,22 +1090,25 @@ def zfill( left of the string to reach a total string length `width`. Strings in the array with length greater or equal to `width` are unchanged. + If `width` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - width : int + width : int or array-like of int Minimum length of resulting string; strings with length less - than `width` be prepended with '0' characters. + than `width` be prepended with '0' characters. If array-like, it is broadcast. Returns ------- filled : same type as values """ - return self.pad(width, side="left", fillchar="0") + return self.rjust(width, fillchar="0") def contains( self, - pat: Union[str, bytes, Pattern], - case: bool = True, + pat: Union[str, bytes, Pattern, Any], + case: bool = None, flags: int = 0, regex: bool = True, ) -> Any: @@ -948,11 +1118,15 @@ def contains( Return boolean array based on whether a given pattern or regex is contained within a string of the array. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern Character sequence, a string containing a regular expression, - or a compiled regular expression object. + or a compiled regular expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -960,7 +1134,7 @@ def contains( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. regex : bool, default: True If True, assumes the pat is a regular expression. @@ -974,42 +1148,54 @@ def contains( given pattern is contained within the string of each element of the array. """ - is_compiled_re = isinstance(pat, self._pattern_type) + is_compiled_re = _contains_compiled_re(pat) if is_compiled_re and not regex: raise ValueError( "Must use regular expression matching for regular expression object." ) if regex: - pat = self._re_compile(pat, flags, case) - if pat.groups > 0: # pragma: no cover - raise ValueError("This pattern has match groups.") + if not is_compiled_re: + pat = self._re_compile(pat=pat, flags=flags, case=case) + + def func(x, ipat): + if ipat.groups > 0: # pragma: no cover + raise ValueError("This pattern has match groups.") + return bool(ipat.search(x)) - f = lambda x: bool(pat.search(x)) else: pat = self._stringify(pat) if case or case is None: - f = lambda x: pat in x + func = lambda x, ipat: ipat in x + elif self._obj.dtype.char == "U": + uppered = self._obj.str.casefold() + uppat = StringAccessor(pat).casefold() + return uppered.str.contains(uppat, regex=False) else: uppered = self._obj.str.upper() - return uppered.str.contains(pat.upper(), regex=False) + uppat = StringAccessor(pat).upper() + return uppered.str.contains(uppat, regex=False) - return self._apply(f, dtype=bool) + return self._apply(func=func, func_args=(pat,), dtype=bool) def match( self, - pat: Union[str, bytes, Pattern], - case: bool = True, + pat: Union[str, bytes, Pattern, Any], + case: bool = None, flags: int = 0, ) -> Any: """ Determine if each string in the array matches a regular expression. + The pattern `pat` can either be a single ``str`` or `re.Pattern` or + array-like of ``str`` or `re.Pattern`. If array-like, it is broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or - a compiled regular expression object. + a compiled regular expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -1017,21 +1203,21 @@ def match( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns ------- matched : array of bool """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) - f = lambda x: bool(pat.match(x)) - return self._apply(f, dtype=bool) + func = lambda x, ipat: bool(ipat.match(x)) + return self._apply(func=func, func_args=(pat,), dtype=bool) def strip( self, - to_strip: Union[str, bytes] = None, + to_strip: Union[str, bytes, Any] = None, side: str = "both", ) -> Any: """ @@ -1040,13 +1226,16 @@ def strip( Strip whitespaces (including newlines) or a set of specified characters from each string in the array from left and/or right sides. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. - side : {"left", "right", "both"}, default: "left" + If None then whitespaces are removed. If array-like, it is broadcast. + side : {"left", "right", "both"}, default: "both" Side from which to strip. Returns @@ -1057,19 +1246,19 @@ def strip( to_strip = self._stringify(to_strip) if side == "both": - f = lambda x: x.strip(to_strip) + func = lambda x, y: x.strip(y) elif side == "left": - f = lambda x: x.lstrip(to_strip) + func = lambda x, y: x.lstrip(y) elif side == "right": - f = lambda x: x.rstrip(to_strip) + func = lambda x, y: x.rstrip(y) else: # pragma: no cover raise ValueError("Invalid side") - return self._apply(f) + return self._apply(func=func, func_args=(to_strip,)) def lstrip( self, - to_strip: Union[str, bytes] = None, + to_strip: Union[str, bytes, Any] = None, ) -> Any: """ Remove leading characters. @@ -1077,12 +1266,15 @@ def lstrip( Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the left side. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. + If None then whitespaces are removed. If array-like, it is broadcast. Returns ------- @@ -1092,7 +1284,7 @@ def lstrip( def rstrip( self, - to_strip: Union[str, bytes] = None, + to_strip: Union[str, bytes, Any] = None, ) -> Any: """ Remove trailing characters. @@ -1100,12 +1292,15 @@ def rstrip( Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the right side. + `to_strip` can either be a ``str`` or array-like of ``str``. + If array-like, it will be broadcast and applied elementwise. + Parameters ---------- - to_strip : str or None, default: None + to_strip : str or array-like of str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. - If None then whitespaces are removed. + If None then whitespaces are removed. If array-like, it is broadcast. Returns ------- @@ -1115,7 +1310,7 @@ def rstrip( def wrap( self, - width: int, + width: Union[int, Any], **kwargs, ) -> Any: """ @@ -1124,10 +1319,14 @@ def wrap( This method has the same keyword parameters and defaults as :class:`textwrap.TextWrapper`. + If `width` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - width : int - Maximum line-width + width : int or array-like of int + Maximum line-width. + If array-like, it is broadcast. **kwargs keyword arguments passed into :class:`textwrap.TextWrapper`. @@ -1135,9 +1334,10 @@ def wrap( ------- wrapped : same type as values """ - tw = textwrap.TextWrapper(width=width, **kwargs) - f = lambda x: "\n".join(tw.wrap(x)) - return self._apply(f) + ifunc = lambda x: textwrap.TextWrapper(width=x, **kwargs) + tw = StringAccessor(width)._apply(func=ifunc, dtype=np.object_) + func = lambda x, itw: "\n".join(itw.wrap(x)) + return self._apply(func=func, func_args=(tw,)) def translate( self, @@ -1158,34 +1358,38 @@ def translate( ------- translated : same type as values """ - f = lambda x: x.translate(table) - return self._apply(f) + func = lambda x: x.translate(table) + return self._apply(func=func) def repeat( self, - repeats: int, + repeats: Union[int, Any], ) -> Any: """ - Duplicate each string in the array. + Repeat each string in the array. + + If `repeats` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - repeats : int + repeats : int or array-like of int Number of repetitions. + If array-like, it is broadcast. Returns ------- repeated : same type as values Array of repeated string objects. """ - f = lambda x: repeats * x - return self._apply(f) + func = lambda x, y: x * y + return self._apply(func=func, func_args=(repeats,)) def find( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, side: str = "left", ) -> Any: """ @@ -1193,14 +1397,20 @@ def find( where the substring is fully contained between [start:end]. Return -1 on failure. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. side : {"left", "right"}, default: "left" Starting side for search. @@ -1217,32 +1427,34 @@ def find( else: # pragma: no cover raise ValueError("Invalid side") - if end is None: - f = lambda x: getattr(x, method)(sub, start) - else: - f = lambda x: getattr(x, method)(sub, start, end) - - return self._apply(f, dtype=int) + func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend) + return self._apply(func=func, func_args=(sub, start, end), dtype=int) def rfind( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, ) -> Any: """ Return highest indexes in each strings in the array where the substring is fully contained between [start:end]. Return -1 on failure. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. Returns ------- @@ -1252,9 +1464,9 @@ def rfind( def index( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, side: str = "left", ) -> Any: """ @@ -1263,14 +1475,20 @@ def index( ``str.find`` except instead of returning -1, it raises a ValueError when the substring is not found. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. side : {"left", "right"}, default: "left" Starting side for search. @@ -1292,18 +1510,14 @@ def index( else: # pragma: no cover raise ValueError("Invalid side") - if end is None: - f = lambda x: getattr(x, method)(sub, start) - else: - f = lambda x: getattr(x, method)(sub, start, end) - - return self._apply(f, dtype=int) + func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend) + return self._apply(func=func, func_args=(sub, start, end), dtype=int) def rindex( self, - sub: Union[str, bytes], - start: int = 0, - end: int = None, + sub: Union[str, bytes, Any], + start: Union[int, Any] = 0, + end: Union[int, Any] = None, ) -> Any: """ Return highest indexes in each strings where the substring is @@ -1311,14 +1525,20 @@ def rindex( ``str.rfind`` except instead of returning -1, it raises a ValueError when the substring is not found. + If `start`, `end`, or 'sub` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - sub : str - Substring being searched - start : int - Left edge index - end : int - Right edge index + sub : str or array-like of str + Substring being searched. + If array-like, it is broadcast. + start : int or array-like of int + Left edge index. + If array-like, it is broadcast. + end : int or array-like of int + Right edge index. + If array-like, it is broadcast. Returns ------- @@ -1333,9 +1553,9 @@ def rindex( def replace( self, - pat: Union[str, bytes, Pattern], - repl: Union[str, bytes, Callable], - n: int = -1, + pat: Union[str, bytes, Pattern, Any], + repl: Union[str, bytes, Callable, Any], + n: Union[int, Any] = -1, case: bool = None, flags: int = 0, regex: bool = True, @@ -1343,16 +1563,22 @@ def replace( """ Replace occurrences of pattern/regex in the array with some string. + If `pat`, `repl`, or 'n` is array-like, they are broadcast + against the array and applied elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern String can be a character sequence or regular expression. - repl : str or callable + If array-like, it is broadcast. + repl : str or callable or array-like of str or callable Replacement string or a callable. The callable is passed the regex match object and must return a replacement string to be used. See :func:`re.sub`. - n : int, default: -1 + If array-like, it is broadcast. + n : int or array of int, default: -1 Number of replacements to make from start. Use ``-1`` to replace all. + If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -1360,7 +1586,7 @@ def replace( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. regex : bool, default: True If True, assumes the passed-in pattern is a regular expression. @@ -1374,13 +1600,12 @@ def replace( A copy of the object with all matching occurrences of `pat` replaced by `repl`. """ - if not _is_str_like(repl) and not callable(repl): # pragma: no cover - raise TypeError("repl must be a string or callable") - - if _is_str_like(repl): + if _contains_str_like(repl): repl = self._stringify(repl) + elif not _contains_callable(repl): # pragma: no cover + raise TypeError("repl must be a string or callable") - is_compiled_re = isinstance(pat, self._pattern_type) + is_compiled_re = _contains_compiled_re(pat) if not regex and is_compiled_re: raise ValueError( "Cannot use a compiled regex as replacement pattern with regex=False" @@ -1390,17 +1615,18 @@ def replace( raise ValueError("Cannot use a callable replacement when regex=False") if regex: - pat = self._re_compile(pat, flags, case) - n = n if n >= 0 else 0 - f = lambda x: pat.sub(repl=repl, string=x, count=n) + pat = self._re_compile(pat=pat, flags=flags, case=case) + func = lambda x, ipat, irepl, i_n: ipat.sub( + repl=irepl, string=x, count=i_n if i_n >= 0 else 0 + ) else: pat = self._stringify(pat) - f = lambda x: x.replace(pat, repl, n) - return self._apply(f) + func = lambda x, ipat, irepl, i_n: x.replace(ipat, irepl, i_n) + return self._apply(func=func, func_args=(pat, repl, n)) def extract( self, - pat: Union[str, bytes, Pattern], + pat: Union[str, bytes, Pattern, Any], dim: Hashable, case: bool = None, flags: int = 0, @@ -1412,12 +1638,15 @@ def extract( For each string in the DataArray, extract groups from the first match of regular expression pat. + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - pat : str or re.Pattern + pat : str or re.Pattern or array-like of str or re.Pattern A string containing a regular expression or a compiled regular - expression object. - dim : hashable or `None` + expression object. If array-like, it is broadcast. + dim : hashable or None Name of the new dimension to store the captured strings in. If None, the pattern must have only one capture group and the resulting DataArray will have the same size as the original. @@ -1428,7 +1657,7 @@ def extract( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns @@ -1440,7 +1669,7 @@ def extract( ValueError `pat` has no capture groups. ValueError - `dim` is `None` and there is more than one capture group. + `dim` is None and there is more than one capture group. ValueError `case` is set when `pat` is a compiled regular expression. KeyError @@ -1487,20 +1716,29 @@ def extract( re.search pandas.Series.str.extract """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) + + if isinstance(pat, re.Pattern): + maxgroups = pat.groups + else: + maxgroups = ( + _apply_str_ufunc(obj=pat, func=lambda x: x.groups, dtype=np.int_) + .max() + .data.tolist() + ) - if pat.groups == 0: + if maxgroups == 0: raise ValueError("No capture groups found in pattern.") - if dim is None and pat.groups != 1: + if dim is None and maxgroups != 1: raise ValueError( - "dim must be specified if more than one capture group is given." + "Dimension must be specified if more than one capture group is given." ) if dim is not None and dim in self._obj.dims: - raise KeyError(f"Dimension {dim} already present in DataArray.") + raise KeyError(f"Dimension '{dim}' already present in DataArray.") - def _get_res_single(val, pat=pat): + def _get_res_single(val, pat): match = pat.search(val) if match is None: return "" @@ -1509,7 +1747,7 @@ def _get_res_single(val, pat=pat): res = "" return res - def _get_res_multi(val, pat=pat): + def _get_res_multi(val, pat): match = pat.search(val) if match is None: return np.array([""], val.dtype) @@ -1518,20 +1756,21 @@ def _get_res_multi(val, pat=pat): return np.array(match, val.dtype) if dim is None: - return self._apply(_get_res_single) + return self._apply(func=_get_res_single, func_args=(pat,)) else: # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 return self._apply( - _get_res_multi, + func=_get_res_multi, + func_args=(pat,), dtype=np.object_, output_core_dims=[[dim]], - output_sizes={dim: pat.groups}, + output_sizes={dim: maxgroups}, ).astype(self._obj.dtype.kind) def extractall( self, - pat: Union[str, bytes, Pattern], + pat: Union[str, bytes, Pattern, Any], group_dim: Hashable, match_dim: Hashable, case: bool = None, @@ -1546,15 +1785,18 @@ def extractall( Equivalent to applying re.findall() to all the elements in the DataArray and splitting the results across dimensions. + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- pat : str or re.Pattern A string containing a regular expression or a compiled regular - expression object. - group_dim: hashable + expression object. If array-like, it is broadcast. + group_dim : hashable Name of the new dimensions corresponding to the capture groups. This dimension is added to the new DataArray first. - match_dim: hashable + match_dim : hashable Name of the new dimensions corresponding to the matches for each group. This dimension is added to the new DataArray second. case : bool, default: True @@ -1564,7 +1806,7 @@ def extractall( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns @@ -1634,7 +1876,6 @@ def extractall( ['', '']]]], dtype=' Any: @@ -1700,11 +1958,14 @@ def findall( If there are multiple capture groups, the lists will be a sequence of lists, each of which contains a sequence of matches. + If `pat` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- pat : str or re.Pattern A string containing a regular expression or a compiled regular - expression object. + expression object. If array-like, it is broadcast. case : bool, default: True If True, case sensitive. Cannot be set if `pat` is a compiled regex. @@ -1712,7 +1973,7 @@ def findall( flags : int, default: 0 Flags to pass through to the re module, e.g. `re.IGNORECASE`. see `compilation-flags `_. - ``0`` means no flags. Flags can be combined with the bitwise or operator `|`. + ``0`` means no flags. Flags can be combined with the bitwise or operator ``|``. Cannot be set if `pat` is a compiled regex. Returns @@ -1765,18 +2026,22 @@ def findall( re.findall pandas.Series.str.findall """ - pat = self._re_compile(pat, flags, case) + pat = self._re_compile(pat=pat, flags=flags, case=case) - if pat.groups == 0: - raise ValueError("No capture groups found in pattern.") + def func(x, ipat): + if ipat.groups == 0: + raise ValueError("No capture groups found in pattern.") + + return ipat.findall(x) - return self._apply(pat.findall, dtype=np.object_) + return self._apply(func=func, func_args=(pat,), dtype=np.object_) def _partitioner( self, + *, func: Callable, dim: Hashable, - sep: Optional[Union[str, bytes]], + sep: Optional[Union[str, bytes, Any]], ) -> Any: """ Implements logic for `partition` and `rpartition`. @@ -1784,19 +2049,20 @@ def _partitioner( sep = self._stringify(sep) if dim is None: - f = lambda x: list(func(x, sep)) - return self._apply(f, dtype=np.object_) + listfunc = lambda x, isep: list(func(x, isep)) + return self._apply(func=listfunc, func_args=(sep,), dtype=np.object_) # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) - f = lambda x: np.array(func(x, sep), dtype=self._obj.dtype) + arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 return self._apply( - f, + func=arrfunc, + func_args=(sep,), dtype=np.object_, output_core_dims=[[dim]], output_sizes={dim: 3}, @@ -1805,7 +2071,7 @@ def _partitioner( def partition( self, dim: Optional[Hashable], - sep: Union[str, bytes] = " ", + sep: Union[str, bytes, Any] = " ", ) -> Any: """ Split the strings in the DataArray at the first occurrence of separator `sep`. @@ -1816,15 +2082,17 @@ def partition( If the separator is not found, return 3 elements containing the string itself, followed by two empty strings. - This is equivalent to :meth:`str.partion`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the 3 elements in. - If `None`, place the results as list elements in an object DataArray - sep : str, default `" "` + If `None`, place the results as list elements in an object DataArray. + sep : str, default: " " String to split on. + If array-like, it is broadcast. Returns ------- @@ -1841,7 +2109,7 @@ def partition( def rpartition( self, dim: Optional[Hashable], - sep: Union[str, bytes] = " ", + sep: Union[str, bytes, Any] = " ", ) -> Any: """ Split the strings in the DataArray at the last occurrence of separator `sep`. @@ -1852,15 +2120,17 @@ def rpartition( If the separator is not found, return 3 elements containing two empty strings, followed by the string itself. - This is equivalent to :meth:`str.rpartion`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the 3 elements in. - If `None`, place the results as list elements in an object DataArray - sep : str, default `" "` + If `None`, place the results as list elements in an object DataArray. + sep : str, default: " " String to split on. + If array-like, it is broadcast. Returns ------- @@ -1876,10 +2146,11 @@ def rpartition( def _splitter( self, + *, func: Callable, pre: bool, dim: Hashable, - sep: Optional[Union[str, bytes]], + sep: Optional[Union[str, bytes, Any]], maxsplit: int, ) -> Any: """ @@ -1889,17 +2160,20 @@ def _splitter( sep = self._stringify(sep) if dim is None: - f = lambda x: func(x, sep, maxsplit) - return self._apply(f, dtype=np.object_) + f_none = lambda x, isep: func(x, isep, maxsplit) + return self._apply(func=f_none, func_args=(sep,), dtype=np.object_) # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) - f_count = lambda x: max(len(func(x, sep, maxsplit)), 1) - maxsplit = self._apply(f_count, dtype=np.int_).max().data.tolist() - 1 + f_count = lambda x, isep: max(len(func(x, isep, maxsplit)), 1) + maxsplit = ( + self._apply(func=f_count, func_args=(sep,), dtype=np.int_).max().data.item() + - 1 + ) - def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): + def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): res = func(mystr, sep, maxsplit) if len(res) < maxsplit + 1: pad = [""] * (maxsplit + 1 - len(res)) @@ -1912,7 +2186,8 @@ def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): # dtype MUST be object or strings can be truncated # See: https://github.com/numpy/numpy/issues/8352 return self._apply( - _dosplit, + func=_dosplit, + func_args=(sep,), dtype=np.object_, output_core_dims=[[dim]], output_sizes={dim: maxsplit}, @@ -1921,7 +2196,7 @@ def _dosplit(mystr, sep=sep, maxsplit=maxsplit, dtype=self._obj.dtype): def split( self, dim: Optional[Hashable], - sep: Union[str, bytes] = None, + sep: Union[str, bytes, Any] = None, maxsplit: int = -1, ) -> Any: """ @@ -1930,18 +2205,20 @@ def split( Splits the string in the DataArray from the beginning, at the specified delimiter string. - This is equivalent to :meth:`str.split`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the results in. - If `None`, place the results as list elements in an object DataArray - sep : str, default is split on any whitespace. - String to split on. - maxsplit : int, default -1 (all) + If `None`, place the results as list elements in an object DataArray. + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + If array-like, it is broadcast. + maxsplit : int, default: -1 Limit number of splits in output, starting from the beginning. - -1 will return all splits. + If -1 (the default), return all splits. Returns ------- @@ -2035,8 +2312,8 @@ def split( def rsplit( self, dim: Optional[Hashable], - sep: Union[str, bytes] = None, - maxsplit: int = -1, + sep: Union[str, bytes, Any] = None, + maxsplit: Union[int, Any] = -1, ) -> Any: """ Split strings in a DataArray around the given separator/delimiter `sep`. @@ -2044,18 +2321,20 @@ def rsplit( Splits the string in the DataArray from the end, at the specified delimiter string. - This is equivalent to :meth:`str.rsplit`. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. Parameters ---------- - dim : Hashable or `None` + dim : hashable or None Name for the dimension to place the results in. If `None`, place the results as list elements in an object DataArray - sep : str, default is split on any whitespace. - String to split on. - maxsplit : int, default -1 (all) + sep : str, default: None + String to split on. If ``None`` (the default), split on any whitespace. + If array-like, it is broadcast. + maxsplit : int, default: -1 Limit number of splits in output, starting from the end. - -1 will return all splits. + If -1 (the default), return all splits. The final number of split values may be less than this if there are no DataArray elements with that many values. @@ -2151,7 +2430,7 @@ def rsplit( def get_dummies( self, dim: Hashable, - sep: Union[str, bytes] = "|", + sep: Union[str, bytes, Any] = "|", ) -> Any: """ Return DataArray of dummy/indicator variables. @@ -2161,12 +2440,16 @@ def get_dummies( and the corresponding element of that dimension is `True` if that result is present and `False` if not. + If `sep` is array-like, it is broadcast against the array and applied + elementwise. + Parameters ---------- - dim : Hashable + dim : hashable Name for the dimension to place the results in. - sep : str, default `"|"`. + sep : str, default: "|". String to split on. + If array-like, it is broadcast. Returns ------- @@ -2205,16 +2488,16 @@ def get_dummies( """ # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, -1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) sep = self._stringify(sep) - f_set = lambda x: set(x.split(sep)) - {self._stringify("")} - setarr = self._apply(f_set, dtype=np.object_) + f_set = lambda x, isep: set(x.split(isep)) - {self._stringify("")} + setarr = self._apply(func=f_set, func_args=(sep,), dtype=np.object_) vals = sorted(reduce(set_union, setarr.data.ravel())) - f = lambda x: np.array([val in x for val in vals], dtype=np.bool_) - res = self._apply( - f, + func = lambda x: np.array([val in x for val in vals], dtype=np.bool_) + res = _apply_str_ufunc( + func=func, obj=setarr, output_core_dims=[[dim]], output_sizes={dim: len(vals)}, @@ -2234,18 +2517,27 @@ def decode( Parameters ---------- encoding : str + The encoding to use. + Please see the Python `codecs `_ documentation for a list + of encodings handlers errors : str, optional + The handler for encoding errors. + Please see the Python `codecs `_ documentation for a list + of error handlers Returns ------- decoded : same type as values + + .. _encodings: https://docs.python.org/3/library/codecs.html#standard-encodings + .. _handlers: https://docs.python.org/3/library/codecs.html#error-handlers """ if encoding in _cpython_optimized_decoders: - f = lambda x: x.decode(encoding, errors) + func = lambda x: x.decode(encoding, errors) else: decoder = codecs.getdecoder(encoding) - f = lambda x: decoder(x, errors)[0] - return self._apply(f, dtype=np.str_) + func = lambda x: decoder(x, errors)[0] + return self._apply(func=func, dtype=np.str_) def encode( self, @@ -2258,15 +2550,24 @@ def encode( Parameters ---------- encoding : str + The encoding to use. + Please see the Python `codecs `_ documentation for a list + of encodings handlers errors : str, optional + The handler for encoding errors. + Please see the Python `codecs `_ documentation for a list + of error handlers Returns ------- encoded : same type as values + + .. _encodings: https://docs.python.org/3/library/codecs.html#standard-encodings + .. _handlers: https://docs.python.org/3/library/codecs.html#error-handlers """ if encoding in _cpython_optimized_encoders: - f = lambda x: x.encode(encoding, errors) + func = lambda x: x.encode(encoding, errors) else: encoder = codecs.getencoder(encoding) - f = lambda x: encoder(x, errors)[0] - return self._apply(f, dtype=np.bytes_) + func = lambda x: encoder(x, errors)[0] + return self._apply(func=func, dtype=np.bytes_) diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index bbc9659668c..9bf33893241 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -69,20 +69,58 @@ def test_dask(): def test_count(dtype): values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) - result = values.str.count("f[o]+") + pat_str = dtype(r"f[o]+") + pat_re = re.compile(pat_str) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + expected = xr.DataArray([1, 2, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) + + +def test_count_broadcast(dtype): + values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) + pat_str = np.array([r"f[o]+", r"o", r"m"]).astype(dtype) + pat_re = np.array([re.compile(x) for x in pat_str]) + + result_str = values.str.count(pat_str) + result_re = values.str.count(pat_re) + + expected = xr.DataArray([1, 4, 3]) + + assert result_str.dtype == expected.dtype + assert result_re.dtype == expected.dtype + assert_equal(result_str, expected) + assert_equal(result_re, expected) def test_contains(dtype): values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"]).astype(dtype) # case insensitive using regex - result = values.str.contains("FOO|mmm", case=False) + pat = values.dtype.type("FOO|mmm") + result = values.str.contains(pat, case=False) expected = xr.DataArray([True, False, True, True]) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.contains(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + pat = values.dtype.type("Foo|mMm") + result = values.str.contains(pat) + expected = xr.DataArray([True, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) # case insensitive without regex result = values.str.contains("foo", regex=False, case=False) @@ -90,6 +128,87 @@ def test_contains(dtype): assert result.dtype == expected.dtype assert_equal(result, expected) + # case sensitive without regex + result = values.str.contains("fO", regex=False, case=True) + expected = xr.DataArray([False, False, True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # regex regex=False + pat_re = re.compile("(/w+)") + with pytest.raises( + ValueError, + match="Must use regular expression matching for regular expression object.", + ): + values.str.contains(pat_re, regex=False) + + +def test_contains_broadcast(dtype): + values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dims="X").astype( + dtype + ) + pat_str = xr.DataArray(["FOO|mmm", "Foo", "MMM"], dims="Y").astype(dtype) + pat_re = xr.DataArray([re.compile(x) for x in pat_str.data], dims="Y") + + # case insensitive using regex + result = values.str.contains(pat_str, case=False) + expected = xr.DataArray( + [ + [True, True, False], + [False, False, False], + [True, True, True], + [True, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case sensitive using regex + result = values.str.contains(pat_str) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.contains(pat_re) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive without regex + result = values.str.contains(pat_str, regex=False, case=False) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, True, True], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # case insensitive with regex + result = values.str.contains(pat_str, regex=False, case=True) + expected = xr.DataArray( + [ + [False, True, False], + [False, False, False], + [False, False, False], + [False, False, True], + ], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + def test_starts_ends_with(dtype): values = xr.DataArray(["om", "foo_nom", "nom", "bar_foo", "foo"]).astype(dtype) @@ -105,15 +224,37 @@ def test_starts_ends_with(dtype): assert_equal(result, expected) -def test_case_bytes(dtype): - dtype = np.bytes_ - value = xr.DataArray(["SOme wOrd"]).astype(dtype) +def test_starts_ends_with_broadcast(dtype): + values = xr.DataArray( + ["om", "foo_nom", "nom", "bar_foo", "foo_bar"], dims="X" + ).astype(dtype) + pat = xr.DataArray(["foo", "bar"], dims="Y").astype(dtype) - exp_capitalized = xr.DataArray(["Some word"]).astype(dtype) - exp_lowered = xr.DataArray(["some word"]).astype(dtype) - exp_swapped = xr.DataArray(["soME WoRD"]).astype(dtype) - exp_titled = xr.DataArray(["Some Word"]).astype(dtype) - exp_uppered = xr.DataArray(["SOME WORD"]).astype(dtype) + result = values.str.startswith(pat) + expected = xr.DataArray( + [[False, False], [True, False], [False, False], [False, True], [True, False]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.endswith(pat) + expected = xr.DataArray( + [[False, False], [False, False], [False, False], [True, False], [False, True]], + dims=["X", "Y"], + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_case_bytes(): + value = xr.DataArray(["SOme wOrd"]).astype(np.bytes_) + + exp_capitalized = xr.DataArray(["Some word"]).astype(np.bytes_) + exp_lowered = xr.DataArray(["some word"]).astype(np.bytes_) + exp_swapped = xr.DataArray(["soME WoRD"]).astype(np.bytes_) + exp_titled = xr.DataArray(["Some Word"]).astype(np.bytes_) + exp_uppered = xr.DataArray(["SOME WORD"]).astype(np.bytes_) res_capitalized = value.str.capitalize() res_lowered = value.str.lower() @@ -127,31 +268,33 @@ def test_case_bytes(dtype): assert res_titled.dtype == exp_titled.dtype assert res_uppered.dtype == exp_uppered.dtype - assert_equal(value.str.capitalize(), exp_capitalized) - assert_equal(value.str.lower(), exp_lowered) - assert_equal(value.str.swapcase(), exp_swapped) - assert_equal(value.str.title(), exp_titled) - assert_equal(value.str.upper(), exp_uppered) + assert_equal(res_capitalized, exp_capitalized) + assert_equal(res_lowered, exp_lowered) + assert_equal(res_swapped, exp_swapped) + assert_equal(res_titled, exp_titled) + assert_equal(res_uppered, exp_uppered) def test_case_str(): - dtype = np.str_ - # This string includes some unicode characters # that are common case management corner cases - value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) - - exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(dtype) - exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(dtype) - exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(dtype) - exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(dtype) - exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(dtype) - exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype(dtype) + value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + + exp_capitalized = xr.DataArray(["Some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) + exp_lowered = xr.DataArray(["some word dž ß ᾓ σς ffi⁵å ç ⅰ"]).astype(np.unicode_) + exp_swapped = xr.DataArray(["soME WoRD dž SS ᾛ σς FFI⁵å ç ⅰ"]).astype(np.unicode_) + exp_titled = xr.DataArray(["Some Word Dž Ss ᾛ Σς Ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_uppered = xr.DataArray(["SOME WORD DŽ SS ἫΙ ΣΣ FFI⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_casefolded = xr.DataArray(["some word dž ss ἣι σσ ffi⁵å ç ⅰ"]).astype( + np.unicode_ + ) - exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) - exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(dtype) - exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(dtype) - exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(dtype) + exp_norm_nfc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_norm_nfkc = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype(np.unicode_) + exp_norm_nfd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) + exp_norm_nfkd = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi5Å Ç I"]).astype( + np.unicode_ + ) res_capitalized = value.str.capitalize() res_casefolded = value.str.casefold() @@ -191,35 +334,52 @@ def test_case_str(): def test_replace(dtype): - values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) result = values.str.replace("BAD[_]*", "") - expected = xr.DataArray(["foobar"]).astype(dtype) + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace("BAD[_]*", "", n=1) - expected = xr.DataArray(["foobarBAD"]).astype(dtype) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) - s = xr.DataArray(["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"]).astype( + pat = xr.DataArray(["BAD[_]*", "AD[_]*"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( dtype ) - result = s.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + values = xr.DataArray( + ["A", "B", "C", "Aaba", "Baca", "", "CABA", "dog", "cat"] + ).astype(dtype) expected = xr.DataArray( ["YYY", "B", "C", "YYYaba", "Baca", "", "CYYYBYYY", "dog", "cat"] ).astype(dtype) + result = values.str.replace("A", "YYY") + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.replace("A", "YYY", regex=False) assert result.dtype == expected.dtype assert_equal(result, expected) - result = s.str.replace("A", "YYY", case=False) + result = values.str.replace("A", "YYY", case=False) expected = xr.DataArray( ["YYY", "B", "C", "YYYYYYbYYY", "BYYYcYYY", "", "CYYYBYYY", "dog", "cYYYt"] ).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) - result = s.str.replace("^.a|dog", "XX-XX ", case=False) + result = values.str.replace("^.a|dog", "XX-XX ", case=False) expected = xr.DataArray( ["A", "B", "C", "XX-XX ba", "XX-XX ca", "", "XX-XX BA", "XX-XX ", "XX-XX t"] ).astype(dtype) @@ -246,6 +406,22 @@ def test_replace_callable(): assert result.dtype == exp.dtype assert_equal(result, exp) + # test broadcast + values = xr.DataArray(["Foo Bar Baz"], dims=["x"]) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = xr.DataArray( + [ + lambda m: m.group("first").swapcase(), + lambda m: m.group("middle").swapcase(), + lambda m: m.group("last").swapcase(), + ], + dims=["Y"], + ) + result = values.str.replace(pat, repl) + exp = xr.DataArray([["fOO", "bAR", "bAZ"]], dims=["x", "Y"]) + assert result.dtype == exp.dtype + assert_equal(result, exp) + def test_replace_unicode(): # flags + unicode @@ -256,18 +432,50 @@ def test_replace_unicode(): assert result.dtype == expected.dtype assert_equal(result, expected) + # broadcast version + values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")], dims=["X"]) + expected = xr.DataArray( + [[b"abcd, \xc3\xa0".decode("utf-8"), b"BAcd,\xc3\xa0".decode("utf-8")]], + dims=["X", "Y"], + ) + pat = xr.DataArray( + [re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE), r"ab"], dims=["Y"] + ) + repl = xr.DataArray([", ", "BA"], dims=["Y"]) + result = values.str.replace(pat, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + def test_replace_compiled_regex(dtype): - values = xr.DataArray(["fooBAD__barBAD"]).astype(dtype) + values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) + # test with compiled regex pat = re.compile(dtype("BAD[_]*")) result = values.str.replace(pat, "") - expected = xr.DataArray(["foobar"]).astype(dtype) + expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) result = values.str.replace(pat, "", n=1) - expected = xr.DataArray(["foobarBAD"]).astype(dtype) + expected = xr.DataArray(["foobarBAD"], dims=["x"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # broadcast + pat = xr.DataArray( + [re.compile(dtype("BAD[_]*")), re.compile(dtype("AD[_]*"))], dims=["y"] + ) + result = values.str.replace(pat, "") + expected = xr.DataArray([["foobar", "fooBbarB"]], dims=["x", "y"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + repl = xr.DataArray(["", "spam"], dims=["y"]).astype(dtype) + result = values.str.replace(pat, repl, n=1) + expected = xr.DataArray([["foobarBAD", "fooBspambarBAD"]], dims=["x", "y"]).astype( + dtype + ) assert result.dtype == expected.dtype assert_equal(result, expected) @@ -276,13 +484,19 @@ def test_replace_compiled_regex(dtype): values = xr.DataArray(["fooBAD__barBAD__bad"]).astype(dtype) pat = re.compile(dtype("BAD[_]*")) - with pytest.raises(ValueError, match="flags cannot be set"): + with pytest.raises( + ValueError, match="Flags cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", flags=re.IGNORECASE) - with pytest.raises(ValueError, match="case cannot be set"): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", case=False) - with pytest.raises(ValueError, match="case cannot be set"): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): result = values.str.replace(pat, "", case=True) # test with callable @@ -297,17 +511,33 @@ def test_replace_compiled_regex(dtype): def test_replace_literal(dtype): # GH16808 literal replace (regex=False vs regex=True) - values = xr.DataArray(["f.o", "foo"]).astype(dtype) - expected = xr.DataArray(["bao", "bao"]).astype(dtype) + values = xr.DataArray(["f.o", "foo"], dims=["X"]).astype(dtype) + expected = xr.DataArray(["bao", "bao"], dims=["X"]).astype(dtype) result = values.str.replace("f.", "ba") assert result.dtype == expected.dtype assert_equal(result, expected) - expected = xr.DataArray(["bao", "foo"]).astype(dtype) + expected = xr.DataArray(["bao", "foo"], dims=["X"]).astype(dtype) result = values.str.replace("f.", "ba", regex=False) assert result.dtype == expected.dtype assert_equal(result, expected) + # Broadcast + pat = xr.DataArray(["f.", ".o"], dims=["yy"]).astype(dtype) + expected = xr.DataArray([["bao", "fba"], ["bao", "bao"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba") + assert result.dtype == expected.dtype + assert_equal(result, expected) + + expected = xr.DataArray([["bao", "fba"], ["foo", "foo"]], dims=["X", "yy"]).astype( + dtype + ) + result = values.str.replace(pat, "ba", regex=False) + assert result.dtype == expected.dtype + assert_equal(result, expected) + # Cannot do a literal replace if given a callable repl or compiled # pattern callable_repl = lambda m: m.group(0).swapcase() @@ -323,143 +553,133 @@ def test_replace_literal(dtype): def test_extract_extractall_findall_empty_raises(dtype): - pat_str = r"a_\w+_b_\d+_c_.*" + pat_str = dtype(r".*") pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extract(pat=pat_str, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extract(pat=pat_re, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extractall(pat=pat_str, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.findall(pat=pat_str) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="No capture groups found in pattern."): value.str.findall(pat=pat_re) def test_extract_multi_None_raises(dtype): - pat_str = r"a_(\w+)_b_(\d+)_c_.*" + pat_str = r"(\w+)_(\d+)" pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a_b"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): value.str.extract(pat=pat_str, dim=None) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Dimension must be specified if more than one capture group is given.", + ): value.str.extract(pat=pat_re, dim=None) def test_extract_extractall_findall_case_re_raises(dtype): - pat_str = r"a_\w+_b_\d+_c_.*" + pat_str = r".*" pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extract(pat=pat_re, case=True, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extract(pat=pat_re, case=False, dim="ZZ") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extractall(pat=pat_re, case=True, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.extractall(pat=pat_re, case=False, group_dim="XX", match_dim="YY") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.findall(pat=pat_re, case=True) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Case cannot be set when pat is a compiled regex." + ): value.str.findall(pat=pat_re, case=False) def test_extract_extractall_name_collision_raises(dtype): - pat_str = r"a_(\w+)_b_\d+_c_.*" + pat_str = r"(\w+)" pat_re = re.compile(pat_str) - value = xr.DataArray( - [ - ["a_first_b_1_c_de", "a_second_b_22_c_efh", "a_third_b_333_c_hijk"], - [ - "a_fourth_b_4444_c_klmno", - "a_fifth_b_5555_c_opqr", - "a_sixth_b_66666_c_rst", - ], - ], - dims=["X", "Y"], - ).astype(dtype) + value = xr.DataArray([["a"]], dims=["X", "Y"]).astype(dtype) - with pytest.raises(KeyError): + with pytest.raises(KeyError, match="Dimension X already present in DataArray."): value.str.extract(pat=pat_str, dim="X") - with pytest.raises(KeyError): + with pytest.raises(KeyError, match="Dimension X already present in DataArray."): value.str.extract(pat=pat_re, dim="X") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension X already present in DataArray." + ): value.str.extractall(pat=pat_str, group_dim="X", match_dim="ZZ") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension X already present in DataArray." + ): value.str.extractall(pat=pat_re, group_dim="X", match_dim="YY") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Match dimension Y already present in DataArray." + ): value.str.extractall(pat=pat_str, group_dim="XX", match_dim="Y") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Match dimension Y already present in DataArray." + ): value.str.extractall(pat=pat_re, group_dim="XX", match_dim="Y") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension ZZ is the same as match dimension ZZ." + ): value.str.extractall(pat=pat_str, group_dim="ZZ", match_dim="ZZ") - with pytest.raises(KeyError): + with pytest.raises( + KeyError, match="Group dimension ZZ is the same as match dimension ZZ." + ): value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") def test_extract_single_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -503,16 +723,16 @@ def test_extract_single_case(dtype): def test_extract_single_nocase(dtype): - pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.IGNORECASE) + pat_str = r"(\w+)?_Xy_\d*" + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.IGNORECASE) value = xr.DataArray( [ ["a_Xy_0", "ab_xY_10-bab_Xy_110-baab_Xy_1100", "abc_Xy_01-cbc_Xy_2210"], [ "abcd_Xy_-dcd_Xy_33210-dccd_Xy_332210", - "", + "_Xy_1", "abcdef_Xy_101-fef_Xy_5543210", ], ], @@ -544,8 +764,8 @@ def test_extract_single_nocase(dtype): def test_extract_multi_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -559,7 +779,7 @@ def test_extract_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [["a", "0"], ["bab", "110"], ["abc", "01"]], [["abcd", ""], ["", ""], ["abcdef", "101"]], @@ -571,19 +791,19 @@ def test_extract_multi_case(dtype): res_re = value.str.extract(pat=pat_re, dim="XX") res_str_case = value.str.extract(pat=pat_str, dim="XX", case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extract_multi_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.IGNORECASE) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.IGNORECASE) value = xr.DataArray( [ @@ -597,7 +817,7 @@ def test_extract_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [["a", "0"], ["ab", "10"], ["abc", "01"]], [["abcd", ""], ["", ""], ["abcdef", "101"]], @@ -608,24 +828,53 @@ def test_extract_multi_nocase(dtype): res_str = value.str.extract(pat=pat_str, dim="XX", case=False) res_re = value.str.extract(pat=pat_re, dim="XX") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + - assert_equal(res_str, targ) - assert_equal(res_re, targ) +def test_extract_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [ + [["a", "0"], ["", ""]], + [["", ""], ["ab", "10"]], + [["abc", "01"], ["", ""]], + ] + expected = xr.DataArray(expected, dims=["X", "Y", "Zz"]).astype(dtype) + + res_str = value.str.extract(pat=pat_str, dim="Zz") + res_re = value.str.extract(pat=pat_re, dim="Zz") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_single_single_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [[[["a"]], [[""]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], dims=["X", "Y", "XX", "YY"], ).astype(dtype) @@ -636,26 +885,26 @@ def test_extractall_single_single_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_single_single_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [[[["a"]], [["ab"]], [["abc"]]], [[["abcd"]], [[""]], [["abcdef"]]]], dims=["X", "Y", "XX", "YY"], ).astype(dtype) @@ -665,17 +914,17 @@ def test_extractall_single_single_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_single_multi_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -689,7 +938,7 @@ def test_extractall_single_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [[["a"], [""], [""]], [["bab"], ["baab"], [""]], [["abc"], ["cbc"], [""]]], [ @@ -707,19 +956,19 @@ def test_extractall_single_multi_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_single_multi_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [ @@ -733,7 +982,7 @@ def test_extractall_single_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [ [["a"], [""], [""]], @@ -754,24 +1003,24 @@ def test_extractall_single_multi_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_multi_single_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [[["a", "0"]], [["", ""]], [["abc", "01"]]], [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], @@ -785,26 +1034,26 @@ def test_extractall_multi_single_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_multi_single_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], [[["abcd", ""]], [["", ""]], [["abcdef", "101"]]], @@ -817,17 +1066,17 @@ def test_extractall_multi_single_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_extractall_multi_multi_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re) value = xr.DataArray( [ @@ -841,7 +1090,7 @@ def test_extractall_multi_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [ [["a", "0"], ["", ""], ["", ""]], @@ -863,19 +1112,19 @@ def test_extractall_multi_multi_case(dtype): pat=pat_str, group_dim="XX", match_dim="YY", case=True ) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_extractall_multi_multi_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") + pat_re = re.compile(pat_re, flags=re.I) value = xr.DataArray( [ @@ -889,7 +1138,7 @@ def test_extractall_multi_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( [ [ [["a", "0"], ["", ""], ["", ""]], @@ -910,71 +1159,96 @@ def test_extractall_multi_multi_nocase(dtype): ) res_re = value.str.extractall(pat=pat_re, group_dim="XX", match_dim="YY") - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) + + +def test_extractall_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) - assert_equal(res_str, targ) - assert_equal(res_re, targ) + pat_str = xr.DataArray( + [r"(\w+)_Xy_(\d*)", r"(\w+)_xY_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [ + [[["a", "0"]], [["", ""]]], + [[["", ""]], [["ab", "10"]]], + [[["abc", "01"]], [["", ""]]], + ] + expected = xr.DataArray(expected, dims=["X", "Y", "ZX", "ZY"]).astype(dtype) + + res_str = value.str.extractall(pat=pat_str, group_dim="ZX", match_dim="ZY") + res_re = value.str.extractall(pat=pat_re, group_dim="ZX", match_dim="ZY") + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_single_single_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[["a"], [], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_single_single_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [[["a"], ["ab"], ["abc"]], [["abcd"], [], ["abcdef"]]] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - print(targ) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[["a"], ["ab"], ["abc"]], [["abcd"], [], ["abcdef"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_single_multi_case(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [ @@ -988,7 +1262,7 @@ def test_findall_single_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [["a"], ["bab", "baab"], ["abc", "cbc"]], [ ["abcd", "dcd", "dccd"], @@ -996,27 +1270,26 @@ def test_findall_single_multi_case(dtype): ["abcdef", "fef"], ], ] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_single_multi_nocase(dtype): pat_str = r"(\w+)_Xy_\d*" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [ @@ -1030,7 +1303,7 @@ def test_findall_single_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [ ["a"], ["ab", "bab", "baab"], @@ -1042,83 +1315,80 @@ def test_findall_single_multi_nocase(dtype): ["abcdef", "fef"], ], ] - targ = [[[conv(x) for x in y] for y in z] for z in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_multi_single_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [[["a", "0"]], [], [["abc", "01"]]], [[["abcd", ""]], [], [["abcdef", "101"]]], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_multi_single_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [["a_Xy_0", "ab_xY_10", "abc_Xy_01"], ["abcd_Xy_", "", "abcdef_Xy_101"]], dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [[["a", "0"]], [["ab", "10"]], [["abc", "01"]]], [[["abcd", ""]], [], [["abcdef", "101"]]], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_findall_multi_multi_case(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str)) + pat_re = re.compile(dtype(pat_str)) value = xr.DataArray( [ @@ -1132,7 +1402,7 @@ def test_findall_multi_multi_case(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [ [["a", "0"]], [["bab", "110"], ["baab", "1100"]], @@ -1144,27 +1414,26 @@ def test_findall_multi_multi_case(dtype): [["abcdef", "101"], ["fef", "5543210"]], ], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str) res_re = value.str.findall(pat=pat_re) res_str_case = value.str.findall(pat=pat_str, case=True) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype - assert res_str_case.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + assert res_str_case.dtype == expected.dtype - assert_equal(res_str, targ) - assert_equal(res_re, targ) - assert_equal(res_str_case, targ) + assert_equal(res_str, expected) + assert_equal(res_re, expected) + assert_equal(res_str_case, expected) def test_findall_multi_multi_nocase(dtype): pat_str = r"(\w+)_Xy_(\d*)" - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - pat_re = re.compile(conv(pat_str), flags=re.I) + pat_re = re.compile(dtype(pat_str), flags=re.I) value = xr.DataArray( [ @@ -1178,7 +1447,7 @@ def test_findall_multi_multi_nocase(dtype): dims=["X", "Y"], ).astype(dtype) - targ = [ + expected = [ [ [["a", "0"]], [["ab", "10"], ["bab", "110"], ["baab", "1100"]], @@ -1190,18 +1459,45 @@ def test_findall_multi_multi_nocase(dtype): [["abcdef", "101"], ["fef", "5543210"]], ], ] - targ = [[[tuple(conv(x) for x in y) for y in z] for z in w] for w in targ] - targ = np.array(targ, dtype=np.object_) - targ = xr.DataArray(targ, dims=["X", "Y"]) + expected = [[[tuple(dtype(x) for x in y) for y in z] for z in w] for w in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) res_str = value.str.findall(pat=pat_str, case=False) res_re = value.str.findall(pat=pat_re) - assert res_str.dtype == targ.dtype - assert res_re.dtype == targ.dtype + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) - assert_equal(res_str, targ) - assert_equal(res_re, targ) + +def test_findall_broadcast(dtype): + value = xr.DataArray( + ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], + dims=["X"], + ).astype(dtype) + + pat_str = xr.DataArray( + [r"(\w+)_Xy_\d*", r"\w+_Xy_(\d*)"], + dims=["Y"], + ).astype(dtype) + pat_re = value.str._re_compile(pat=pat_str) + + expected = [[["a"], ["0"]], [[], []], [["abc"], ["01"]]] + expected = [[[dtype(x) for x in y] for y in z] for z in expected] + expected = np.array(expected, dtype=np.object_) + expected = xr.DataArray(expected, dims=["X", "Y"]) + + res_str = value.str.findall(pat=pat_str) + res_re = value.str.findall(pat=pat_re) + + assert res_str.dtype == expected.dtype + assert res_re.dtype == expected.dtype + + assert_equal(res_str, expected) + assert_equal(res_re, expected) def test_repeat(dtype): @@ -1219,19 +1515,57 @@ def test_repeat(dtype): assert_equal(result, expected) +def test_repeat_broadcast(dtype): + values = xr.DataArray(["a", "b", "c", "d"], dims=["X"]).astype(dtype) + reps = xr.DataArray([3, 4], dims=["Y"]) + + result = values.str.repeat(reps) + result_mul = values.str * reps + + expected = xr.DataArray( + [["aaa", "aaaa"], ["bbb", "bbbb"], ["ccc", "cccc"], ["ddd", "dddd"]], + dims=["X", "Y"], + ).astype(dtype) + + assert result.dtype == expected.dtype + assert result_mul.dtype == expected.dtype + + assert_equal(result_mul, expected) + assert_equal(result, expected) + + def test_match(dtype): - # New match behavior introduced in 0.13 values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) - result = values.str.match(".*(BAD[_]+).*(BAD)") + + # New match behavior introduced in 0.13 + pat = values.dtype.type(".*(BAD[_]+).*(BAD)") + result = values.str.match(pat) expected = xr.DataArray([True, False]) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) - values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) - result = values.str.match(".*BAD[_]+.*BAD") + # Case-sensitive + pat = values.dtype.type(".*BAD[_]+.*BAD") + result = values.str.match(pat) expected = xr.DataArray([True, False]) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.match(re.compile(pat)) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + # Case-insensitive + pat = values.dtype.type(".*bAd[_]+.*bad") + result = values.str.match(pat, case=False) + expected = xr.DataArray([True, False]) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.match(re.compile(pat, flags=re.IGNORECASE)) + assert result.dtype == expected.dtype + assert_equal(result, expected) def test_empty_str_methods(): @@ -1400,132 +1734,221 @@ def test_len(dtype): def test_find(dtype): values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) values = values.astype(dtype) - result = values.str.find("EF") - expected = xr.DataArray([4, 3, 1, 0, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.find(dtype("EF")) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - result = values.str.rfind("EF") - expected = xr.DataArray([4, 5, 7, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.find("EF") + result_1 = values.str.find("EF", side="left") + expected_0 = xr.DataArray([4, 3, 1, 0, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF") + result_1 = values.str.find("EF", side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF")) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3) + result_1 = values.str.find("EF", 3, side="left") + expected_0 = xr.DataArray([4, 3, 7, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3) + result_1 = values.str.find("EF", 3, side="right") + expected_0 = xr.DataArray([4, 5, 7, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.find("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="left") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + result_0 = values.str.rfind("EF", 3, 6) + result_1 = values.str.find("EF", 3, 6, side="right") + expected_0 = xr.DataArray([4, 3, -1, 4, -1]) + expected_1 = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) + assert result_0.dtype == expected_0.dtype + assert result_0.dtype == expected_1.dtype + assert result_1.dtype == expected_0.dtype + assert result_1.dtype == expected_1.dtype + assert_equal(result_0, expected_0) + assert_equal(result_0, expected_1) + assert_equal(result_1, expected_0) + assert_equal(result_1, expected_1) + + +def test_find_broadcast(dtype): + values = xr.DataArray( + ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"], dims=["X"] + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC", "XX"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 7], dims=["Z"]) + end = xr.DataArray([6, 9], dims=["Z"]) - result = values.str.find("EF", 3) - expected = xr.DataArray([4, 3, 7, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.find(dtype("EF"), 3) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.find(sub, start, end) + result_1 = values.str.find(sub, start, end, side="left") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [0, -1]], + ], + dims=["X", "Y", "Z"], + ) - result = values.str.rfind("EF", 3) - expected = xr.DataArray([4, 5, 7, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.rfind(dtype("EF"), 3) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = values.str.find("EF", 3, 6) - expected = xr.DataArray([4, 3, -1, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - expected = xr.DataArray([v.find(dtype("EF"), 3, 6) for v in values.values]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.rfind(sub, start, end) + result_1 = values.str.find(sub, start, end, side="right") + expected = xr.DataArray( + [ + [[4, -1], [1, -1], [-1, -1]], + [[3, -1], [0, -1], [-1, -1]], + [[1, 7], [-1, -1], [-1, -1]], + [[4, -1], [-1, -1], [-1, -1]], + [[-1, -1], [-1, -1], [1, -1]], + ], + dims=["X", "Y", "Z"], + ) - result = values.str.rfind("EF", 3, 6) - expected = xr.DataArray([4, 3, -1, 4, -1]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - xp = xr.DataArray([v.rfind(dtype("EF"), 3, 6) for v in values.values]) - assert result.dtype == xp.dtype - assert_equal(result, xp) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) def test_index(dtype): s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) - result = s.str.index("EF") + result_0 = s.str.index("EF") + result_1 = s.str.index("EF", side="left") expected = xr.DataArray([4, 3, 1, 0]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.rindex("EF") + result_0 = s.str.rindex("EF") + result_1 = s.str.index("EF", side="right") expected = xr.DataArray([4, 5, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.index("EF", 3) + result_0 = s.str.index("EF", 3) + result_1 = s.str.index("EF", 3, side="left") expected = xr.DataArray([4, 3, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.rindex("EF", 3) + result_0 = s.str.rindex("EF", 3) + result_1 = s.str.index("EF", 3, side="right") expected = xr.DataArray([4, 5, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.index("E", 4, 8) + result_0 = s.str.index("E", 4, 8) + result_1 = s.str.index("E", 4, 8, side="left") expected = xr.DataArray([4, 5, 7, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = s.str.rindex("E", 0, 5) + result_0 = s.str.rindex("E", 0, 5) + result_1 = s.str.index("E", 0, 5, side="right") expected = xr.DataArray([4, 3, 1, 4]) - assert result.dtype == expected.dtype - assert_equal(result, expected) - - with pytest.raises(ValueError): - result = s.str.index("DE") - + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) -def test_pad(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) + matchtype = "subsection" if dtype == np.bytes_ else "substring" + with pytest.raises(ValueError, match=f"{matchtype} not found"): + s.str.index("DE") - result = values.str.pad(5, side="left") - expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) - result = values.str.pad(5, side="right") - expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) - - result = values.str.pad(5, side="both") - expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) - - -def test_pad_fillchar(dtype): - values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) +def test_index_broadcast(dtype): + values = xr.DataArray( + ["ABCDEFGEFDBCA", "BCDEFEFEFDBC", "DEFBCGHIEFBC", "EFGHBCEFBCBCBCEF"], + dims=["X"], + ) + values = values.astype(dtype) + sub = xr.DataArray(["EF", "BC"], dims=["Y"]).astype(dtype) + start = xr.DataArray([0, 6], dims=["Z"]) + end = xr.DataArray([6, 12], dims=["Z"]) - result = values.str.pad(5, side="left", fillchar="X") - expected = xr.DataArray(["XXXXa", "XXXXb", "XXXXc", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.index(sub, start, end) + result_1 = values.str.index(sub, start, end, side="left") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 8]]], + dims=["X", "Y", "Z"], + ) - result = values.str.pad(5, side="right", fillchar="X") - expected = xr.DataArray(["aXXXX", "bXXXX", "cXXXX", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) - result = values.str.pad(5, side="both", fillchar="X") - expected = xr.DataArray(["XXaXX", "XXbXX", "XXcXX", "eeeee"]).astype(dtype) - assert result.dtype == expected.dtype - assert_equal(result, expected) + result_0 = values.str.rindex(sub, start, end) + result_1 = values.str.index(sub, start, end, side="right") + expected = xr.DataArray( + [[[4, 7], [1, 10]], [[3, 7], [0, 10]], [[1, 8], [3, 10]], [[0, 6], [4, 10]]], + dims=["X", "Y", "Z"], + ) - msg = "fillchar must be a character, not str" - with pytest.raises(TypeError, match=msg): - result = values.str.pad(5, fillchar="XY") + assert result_0.dtype == expected.dtype + assert result_1.dtype == expected.dtype + assert_equal(result_0, expected) + assert_equal(result_1, expected) def test_translate(): @@ -1537,41 +1960,66 @@ def test_translate(): assert_equal(result, expected) -def test_center_ljust_rjust(dtype): +def test_pad_center_ljust_rjust(dtype): values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) result = values.str.center(5) expected = xr.DataArray([" a ", " b ", " c ", "eeeee"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.pad(5, side="both") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.ljust(5) expected = xr.DataArray(["a ", "b ", "c ", "eeeee"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.pad(5, side="right") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.rjust(5) expected = xr.DataArray([" a", " b", " c", "eeeee"]).astype(dtype) assert result.dtype == expected.dtype assert_equal(result, expected) + result = values.str.pad(5, side="left") + assert result.dtype == expected.dtype + assert_equal(result, expected) -def test_center_ljust_rjust_fillchar(dtype): +def test_pad_center_ljust_rjust_fillchar(dtype): values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) + result = values.str.center(5, fillchar="X") - expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]) - assert result.dtype == expected.astype(dtype).dtype - assert_equal(result, expected.astype(dtype)) + expected = xr.DataArray(["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(5, side="both", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.ljust(5, fillchar="X") - expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]) - assert result.dtype == expected.astype(dtype).dtype + expected = xr.DataArray(["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="right", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.rjust(5, fillchar="X") - expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]) - assert result.dtype == expected.astype(dtype).dtype + expected = xr.DataArray(["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"]).astype( + dtype + ) + assert result.dtype == expected.dtype assert_equal(result, expected.astype(dtype)) + result = values.str.pad(5, side="left", fillchar="X") + assert result.dtype == expected.dtype + assert_equal(result, expected) # If fillchar is not a charatter, normal str raises TypeError # 'aaa'.ljust(5, 'XY') @@ -1587,19 +2035,91 @@ def test_center_ljust_rjust_fillchar(dtype): with pytest.raises(TypeError, match=template.format(dtype="str")): values.str.rjust(5, fillchar="XY") + with pytest.raises(TypeError, match=template.format(dtype="str")): + values.str.pad(5, fillchar="XY") + + +def test_pad_center_ljust_rjust_broadcast(dtype): + values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"], dims="X").astype( + dtype + ) + width = xr.DataArray([5, 4], dims="Y") + fillchar = xr.DataArray(["X", "#"], dims="Y").astype(dtype) + + result = values.str.center(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXaXX", "#a##"], + ["XXbbX", "#bb#"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + result = values.str.pad(width, side="both", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.ljust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["aXXXX", "a###"], + ["bbXXX", "bb##"], + ["ccccX", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="right", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rjust(width, fillchar=fillchar) + expected = xr.DataArray( + [ + ["XXXXa", "###a"], + ["XXXbb", "##bb"], + ["Xcccc", "cccc"], + ["ddddd", "ddddd"], + ["eeeeee", "eeeeee"], + ], + dims=["X", "Y"], + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected.astype(dtype)) + result = values.str.pad(width, side="left", fillchar=fillchar) + assert result.dtype == expected.dtype + assert_equal(result, expected) + def test_zfill(dtype): values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) result = values.str.zfill(5) - expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]) - assert result.dtype == expected.astype(dtype).dtype - assert_equal(result, expected.astype(dtype)) + expected = xr.DataArray(["00001", "00022", "00aaa", "00333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) result = values.str.zfill(3) - expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]) - assert result.dtype == expected.astype(dtype).dtype - assert_equal(result, expected.astype(dtype)) + expected = xr.DataArray(["001", "022", "aaa", "333", "45678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_zfill_broadcast(dtype): + values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) + width = np.array([4, 5, 0, 3, 8]) + + result = values.str.zfill(width) + expected = xr.DataArray(["0001", "00022", "aaa", "333", "00045678"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) def test_slice(dtype): @@ -1620,6 +2140,17 @@ def test_slice(dtype): raise +def test_slice_broadcast(dtype): + arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) + start = xr.DataArray([1, 2, 3]) + stop = 5 + + result = arr.str.slice(start=start, stop=stop) + exp = xr.DataArray(["afoo", "bar", "az"]).astype(dtype) + assert result.dtype == exp.dtype + assert_equal(result, exp) + + def test_slice_replace(dtype): da = lambda x: xr.DataArray(x).astype(dtype) values = da(["short", "a bit longer", "evenlongerthanthat", ""]) @@ -1665,6 +2196,22 @@ def test_slice_replace(dtype): assert_equal(result, expected) +def test_slice_replace_broadcast(dtype): + values = xr.DataArray(["short", "a bit longer", "evenlongerthanthat", ""]).astype( + dtype + ) + start = 2 + stop = np.array([4, 5, None, 7]) + repl = "test" + + expected = xr.DataArray(["shtestt", "a test longer", "evtest", "test"]).astype( + dtype + ) + result = values.str.slice_replace(start, stop, repl) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def test_strip_lstrip_rstrip(dtype): values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) @@ -1687,20 +2234,40 @@ def test_strip_lstrip_rstrip(dtype): def test_strip_lstrip_rstrip_args(dtype): values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) - rs = values.str.strip("x") - xp = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) - assert rs.dtype == xp.dtype - assert_equal(rs, xp) + result = values.str.strip("x") + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.lstrip("x") + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + result = values.str.rstrip("x") + expected = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + +def test_strip_lstrip_rstrip_broadcast(dtype): + values = xr.DataArray(["xxABCxx", "yy BNSD", "LDFJH zz"]).astype(dtype) + to_strip = xr.DataArray(["x", "y", "z"]).astype(dtype) + + result = values.str.strip(to_strip) + expected = xr.DataArray(["ABC", " BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) - rs = values.str.lstrip("x") - xp = xr.DataArray(["ABCxx", " BNSD", "LDFJH xx"]).astype(dtype) - assert rs.dtype == xp.dtype - assert_equal(rs, xp) + result = values.str.lstrip(to_strip) + expected = xr.DataArray(["ABCxx", " BNSD", "LDFJH zz"]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) - rs = values.str.rstrip("x") - xp = xr.DataArray(["xxABC", "xx BNSD", "LDFJH "]).astype(dtype) - assert rs.dtype == xp.dtype - assert_equal(rs, xp) + result = values.str.rstrip(to_strip) + expected = xr.DataArray(["xxABC", "yy BNSD", "LDFJH "]).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) def test_wrap(): @@ -1800,6 +2367,18 @@ def test_get_default(dtype): assert_equal(result, expected) +def test_get_broadcast(dtype): + values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"], dims=["X"]).astype(dtype) + inds = xr.DataArray([0, 2], dims=["Y"]) + + result = values.str.get(inds) + expected = xr.DataArray( + [["a", "b"], ["c", "d"], ["f", "g"]], dims=["X", "Y"] + ).astype(dtype) + assert result.dtype == expected.dtype + assert_equal(result, expected) + + def test_encode_decode(): data = xr.DataArray(["a", "b", "a\xe4"]) encoded = data.str.encode("utf-8") @@ -1938,6 +2517,16 @@ def test_partition_comma(dtype): assert_equal(res_rpart_dim, exp_rpart_dim) +def test_partition_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.partition(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + def test_split_whitespace(dtype): values = xr.DataArray( [ @@ -2003,16 +2592,14 @@ def test_split_whitespace(dtype): [["test0\ntest1\ntest2", "test3"], [], ["abra ka\nda", "bra"]], ] - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - exp_split_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_split_none_full + [[dtype(x) for x in y] for y in z] for z in exp_split_none_full ] exp_rsplit_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_rsplit_none_full + [[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_full ] - exp_split_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_split_none_1] - exp_rsplit_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_rsplit_none_1] + exp_split_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_1] exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) @@ -2143,16 +2730,14 @@ def test_split_comma(dtype): [["test0,test1,test2", "test3"], [""], ["abra,ka,da", "bra"]], ] - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - exp_split_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_split_none_full + [[dtype(x) for x in y] for y in z] for z in exp_split_none_full ] exp_rsplit_none_full = [ - [[conv(x) for x in y] for y in z] for z in exp_rsplit_none_full + [[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_full ] - exp_split_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_split_none_1] - exp_rsplit_none_1 = [[[conv(x) for x in y] for y in z] for z in exp_rsplit_none_1] + exp_split_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_split_none_1] + exp_rsplit_none_1 = [[[dtype(x) for x in y] for y in z] for z in exp_rsplit_none_1] exp_split_none_full = np.array(exp_split_none_full, dtype=np.object_) exp_rsplit_none_full = np.array(exp_rsplit_none_full, dtype=np.object_) @@ -2218,6 +2803,80 @@ def test_split_comma(dtype): assert_equal(res_rsplit_none_10, exp_rsplit_none_full) +def test_splitters_broadcast(dtype): + values = xr.DataArray( + ["ab cd,de fg", "spam, ,eggs swallow", "red_blue"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + [" ", ","], + dims=["Y"], + ).astype(dtype) + + expected_left = xr.DataArray( + [ + [["ab", "cd,de fg"], ["ab cd", "de fg"]], + [["spam,", ",eggs swallow"], ["spam", " ,eggs swallow"]], + [["red_blue", ""], ["red_blue", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab cd,de", "fg"], ["ab cd", "de fg"]], + [["spam, ,eggs", "swallow"], ["spam, ", "eggs swallow"]], + [["", "red_blue"], ["", "red_blue"]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.split(dim="ZZ", sep=sep, maxsplit=1) + res_right = values.str.rsplit(dim="ZZ", sep=sep, maxsplit=1) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + expected_left = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + expected_right = xr.DataArray( + [ + [["ab", " ", "cd,de fg"], ["ab cd", ",", "de fg"]], + [["spam,", " ", ",eggs swallow"], ["spam", ",", " ,eggs swallow"]], + [["red_blue", "", ""], ["red_blue", "", ""]], + ], + dims=["X", "Y", "ZZ"], + ).astype(dtype) + + res_left = values.str.partition(dim="ZZ", sep=sep) + res_right = values.str.partition(dim="ZZ", sep=sep) + + # assert res_left.dtype == expected_left.dtype + # assert res_right.dtype == expected_right.dtype + + assert_equal(res_left, expected_left) + assert_equal(res_right, expected_right) + + +def test_split_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.split(sep=", ", dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + def test_get_dummies(dtype): values_line = xr.DataArray( [["a|ab~abc|abc", "ab", "a||abc|abcd"], ["abcd|ab|a", "abc|ab~abc", "|a"]], @@ -2230,7 +2889,7 @@ def test_get_dummies(dtype): vals_line = np.array(["a", "ab", "abc", "abcd", "ab~abc"]).astype(dtype) vals_comma = np.array(["a", "ab", "abc", "abcd", "ab|abc"]).astype(dtype) - targ = [ + expected = [ [ [True, False, True, False, True], [False, True, False, False, False], @@ -2242,10 +2901,10 @@ def test_get_dummies(dtype): [True, False, False, False, False], ], ] - targ = np.array(targ) - targ = xr.DataArray(targ, dims=["X", "Y", "ZZ"]) - targ_line = targ.copy() - targ_comma = targ.copy() + expected = np.array(expected) + expected = xr.DataArray(expected, dims=["X", "Y", "ZZ"]) + targ_line = expected.copy() + targ_comma = expected.copy() targ_line.coords["ZZ"] = vals_line targ_comma.coords["ZZ"] = vals_comma @@ -2262,14 +2921,50 @@ def test_get_dummies(dtype): assert_equal(res_comma, targ_comma) +def test_get_dummies_broadcast(dtype): + values = xr.DataArray( + ["x~x|x~x", "x", "x|x~x", "x~x"], + dims=["X"], + ).astype(dtype) + + sep = xr.DataArray( + ["|", "~"], + dims=["Y"], + ).astype(dtype) + + expected = [ + [[False, False, True], [True, True, False]], + [[True, False, False], [True, False, False]], + [[True, False, True], [True, True, False]], + [[False, False, True], [True, False, False]], + ] + expected = np.array(expected) + expected = xr.DataArray(expected, dims=["X", "Y", "ZZ"]) + expected.coords["ZZ"] = np.array(["x", "x|x", "x~x"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ", sep=sep) + + assert res.dtype == expected.dtype + + assert_equal(res, expected) + + +def test_get_dummies_empty(dtype): + values = xr.DataArray([], dims=["X"]).astype(dtype) + expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) + + res = values.str.get_dummies(dim="ZZ") + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + def test_splitters_empty_str(dtype): values = xr.DataArray( [["", "", ""], ["", "", ""]], dims=["X", "Y"], ).astype(dtype) - conv = {np.unicode_: str, np.bytes_: lambda x: bytes(x, encoding="UTF-8")}[dtype] - targ_partition_dim = xr.DataArray( [ [["", "", ""], ["", "", ""], ["", "", ""]], @@ -2283,7 +2978,7 @@ def test_splitters_empty_str(dtype): [["", "", ""], ["", "", ""], ["", "", "", ""]], ] targ_partition_none = [ - [[conv(x) for x in y] for y in z] for z in targ_partition_none + [[dtype(x) for x in y] for y in z] for z in targ_partition_none ] targ_partition_none = np.array(targ_partition_none, dtype=np.object_) del targ_partition_none[-1, -1][-1] @@ -2339,58 +3034,6 @@ def test_splitters_empty_str(dtype): assert_equal(res_dummies, targ_split_dim) -def test_splitters_empty_array(dtype): - values = xr.DataArray( - [[], []], - dims=["X", "Y"], - ).astype(dtype) - - targ_dim = xr.DataArray( - np.empty([2, 0, 0]), - dims=["X", "Y", "ZZ"], - ).astype(dtype) - targ_none = xr.DataArray( - np.empty([2, 0]), - dims=["X", "Y"], - ).astype(np.object_) - - res_part_dim = values.str.partition(dim="ZZ") - res_rpart_dim = values.str.rpartition(dim="ZZ") - res_part_none = values.str.partition(dim=None) - res_rpart_none = values.str.rpartition(dim=None) - - res_split_dim = values.str.split(dim="ZZ") - res_rsplit_dim = values.str.rsplit(dim="ZZ") - res_split_none = values.str.split(dim=None) - res_rsplit_none = values.str.rsplit(dim=None) - - res_dummies = values.str.get_dummies(dim="ZZ") - - assert res_part_dim.dtype == targ_dim.dtype - assert res_rpart_dim.dtype == targ_dim.dtype - assert res_part_none.dtype == targ_none.dtype - assert res_rpart_none.dtype == targ_none.dtype - - assert res_split_dim.dtype == targ_dim.dtype - assert res_rsplit_dim.dtype == targ_dim.dtype - assert res_split_none.dtype == targ_none.dtype - assert res_rsplit_none.dtype == targ_none.dtype - - assert res_dummies.dtype == targ_dim.dtype - - assert_equal(res_part_dim, targ_dim) - assert_equal(res_rpart_dim, targ_dim) - assert_equal(res_part_none, targ_none) - assert_equal(res_rpart_none, targ_none) - - assert_equal(res_split_dim, targ_dim) - assert_equal(res_rsplit_dim, targ_dim) - assert_equal(res_split_none, targ_none) - assert_equal(res_rsplit_none, targ_none) - - assert_equal(res_dummies, targ_dim) - - def test_cat_str(dtype): values_1 = xr.DataArray( [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], @@ -2666,7 +3309,6 @@ def test_cat_broadcast_both(dtype): def test_cat_multi(): - dtype = np.unicode_ values_1 = xr.DataArray( ["11111", "4"], dims=["X"], @@ -2686,9 +3328,9 @@ def test_cat_multi(): sep = xr.DataArray( [" ", ", "], dims=["ZZ"], - ).astype(dtype) + ).astype(np.unicode_) - targ = xr.DataArray( + expected = xr.DataArray( [ [ ["11111 a 3.4 ", "11111, a, 3.4, , "], @@ -2702,12 +3344,27 @@ def test_cat_multi(): ], ], dims=["X", "Y", "ZZ"], - ).astype(dtype) + ).astype(np.unicode_) res = values_1.str.cat(values_2, values_3, values_4, values_5, sep=sep) - assert res.dtype == targ.dtype - assert_equal(res, targ) + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_join_scalar(dtype): + values = xr.DataArray("aaa").astype(dtype) + + targ = xr.DataArray("aaa").astype(dtype) + + res_blank = values.str.join() + res_space = values.str.join(sep=" ") + + assert res_blank.dtype == targ.dtype + assert res_space.dtype == targ.dtype + + assert_identical(res_blank, targ) + assert_identical(res_space, targ) def test_join_vector(dtype): @@ -2723,7 +3380,7 @@ def test_join_vector(dtype): res_blank_y = values.str.join(dim="Y") res_space_none = values.str.join(sep=" ") - res_space_y = values.str.join(sep=" ", dim="Y") + res_space_y = values.str.join(dim="Y", sep=" ") assert res_blank_none.dtype == targ_blank.dtype assert res_blank_y.dtype == targ_blank.dtype @@ -2764,7 +3421,7 @@ def test_join_2d(dtype): res_blank_y = values.str.join(dim="Y") res_space_x = values.str.join(dim="X", sep=" ") - res_space_y = values.str.join(sep=" ", dim="Y") + res_space_y = values.str.join(dim="Y", sep=" ") assert res_blank_x.dtype == targ_blank_x.dtype assert res_blank_y.dtype == targ_blank_y.dtype @@ -2776,7 +3433,9 @@ def test_join_2d(dtype): assert_identical(res_space_x, targ_space_x) assert_identical(res_space_y, targ_space_y) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Dimension must be specified for multidimensional arrays." + ): values.str.join() @@ -2791,23 +3450,22 @@ def test_join_broadcast(dtype): dims=["ZZ"], ).astype(dtype) - targ = xr.DataArray( + expected = xr.DataArray( ["a bb cccc", "a, bb, cccc"], dims=["ZZ"], ).astype(dtype) res = values.str.join(sep=sep) - assert res.dtype == targ.dtype - assert_identical(res, targ) + assert res.dtype == expected.dtype + assert_identical(res, expected) def test_format_scalar(): - dtype = np.unicode_ values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(dtype) + ).astype(np.unicode_) pos0 = 1 pos1 = 1.2 @@ -2817,23 +3475,22 @@ def test_format_scalar(): ZZ = None W = "NO!" - targ = xr.DataArray( + expected = xr.DataArray( ["1.X.None", "1,1.2,'test','test'", "'test'-X-None"], dims=["X"], - ).astype(dtype) + ).astype(np.unicode_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) - assert res.dtype == targ.dtype - assert_equal(res, targ) + assert res.dtype == expected.dtype + assert_equal(res, expected) -def test_format_broadcast(dtype): - dtype = np.unicode_ +def test_format_broadcast(): values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], - ).astype(dtype) + ).astype(np.unicode_) pos0 = 1 pos1 = 1.2 @@ -2848,16 +3505,109 @@ def test_format_broadcast(dtype): ZZ = None W = "NO!" - targ = xr.DataArray( + expected = xr.DataArray( [ ["1.X.None", "1.X.None"], ["1,1.2,'test','test'", "1,1.2,'test','test'"], ["'test'-X-None", "'test'-X-None"], ], dims=["X", "YY"], - ).astype(dtype) + ).astype(np.unicode_) res = values.str.format(pos0, pos1, pos2, X=X, Y=Y, ZZ=ZZ, W=W) - assert res.dtype == targ.dtype - assert_equal(res, targ) + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_scalar(): + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + pos2 = "2.3" + + expected = xr.DataArray( + ["1.1.2.2.3", "1,1.2,2.3", "1-1.2-2.3"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_dict(): + values = xr.DataArray( + ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], + dims=["X"], + ).astype(np.unicode_) + + a = 1 + b = 1.2 + c = "2.3" + + expected = xr.DataArray( + ["1.1.1.2", "1.2,2.3,1.2", "2.3-1.2-1"], + dims=["X"], + ).astype(np.unicode_) + + res = values.str % {"a": a, "b": b, "c": c} + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_single(): + values = xr.DataArray( + ["%s_1", "%s_2", "%s_3"], + dims=["X"], + ).astype(np.unicode_) + + pos = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [["2.3_1", "3.44444_1"], ["2.3_2", "3.44444_2"], ["2.3_3", "3.44444_3"]], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str % pos + + assert res.dtype == expected.dtype + assert_equal(res, expected) + + +def test_mod_broadcast_multi(): + values = xr.DataArray( + ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], + dims=["X"], + ).astype(np.unicode_) + + pos0 = 1 + pos1 = 1.2 + + pos2 = xr.DataArray( + ["2.3", "3.44444"], + dims=["YY"], + ) + + expected = xr.DataArray( + [ + ["1.1.2.2.3", "1.1.2.3.44444"], + ["1,1.2,2.3", "1,1.2,3.44444"], + ["1-1.2-2.3", "1-1.2-3.44444"], + ], + dims=["X", "YY"], + ).astype(np.unicode_) + + res = values.str % (pos0, pos1, pos2) + + assert res.dtype == expected.dtype + assert_equal(res, expected)