diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 5cdd9472277..eb8c1dbc42a 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -2,6 +2,7 @@ import warnings from datetime import datetime, timedelta from functools import partial +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -27,6 +28,9 @@ except ImportError: cftime = None +if TYPE_CHECKING: + from ..core.types import CFCalendar + # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} @@ -344,7 +348,7 @@ def _infer_time_units_from_diff(unique_timedeltas): return "seconds" -def infer_calendar_name(dates): +def infer_calendar_name(dates) -> "CFCalendar": """Given an array of datetimes, infer the CF calendar name""" if is_np_datetime_like(dates.dtype): return "proleptic_gregorian" diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 7f8bf79a50a..c90ad204a4a 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import warnings +from typing import TYPE_CHECKING, Generic import numpy as np import pandas as pd @@ -11,6 +14,12 @@ ) from .npcompat import DTypeLike from .pycompat import is_duck_dask_array +from .types import T_DataArray + +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + from .types import CFCalendar def _season_from_months(months): @@ -156,7 +165,7 @@ def _round_field(values, name, freq): return _round_through_series_or_index(values, name, freq) -def _strftime_through_cftimeindex(values, date_format): +def _strftime_through_cftimeindex(values, date_format: str): """Coerce an array of cftime-like values to a CFTimeIndex and access requested datetime component """ @@ -168,7 +177,7 @@ def _strftime_through_cftimeindex(values, date_format): return field_values.values.reshape(values.shape) -def _strftime_through_series(values, date_format): +def _strftime_through_series(values, date_format: str): """Coerce an array of datetime-like values to a pandas Series and apply string formatting """ @@ -190,33 +199,26 @@ def _strftime(values, date_format): return access_method(values, date_format) -class Properties: - def __init__(self, obj): - self._obj = obj +class TimeAccessor(Generic[T_DataArray]): - @staticmethod - def _tslib_field_accessor( - name: str, docstring: str = None, dtype: DTypeLike = None - ): - def f(self, dtype=dtype): - if dtype is None: - dtype = self._obj.dtype - obj_type = type(self._obj) - result = _get_date_field(self._obj.data, name, dtype) - return obj_type( - result, name=name, coords=self._obj.coords, dims=self._obj.dims - ) + __slots__ = ("_obj",) - f.__name__ = name - f.__doc__ = docstring - return property(f) + def __init__(self, obj: T_DataArray) -> None: + self._obj = obj + + def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray: + if dtype is None: + dtype = self._obj.dtype + obj_type = type(self._obj) + result = _get_date_field(self._obj.data, name, dtype) + return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims) - def _tslib_round_accessor(self, name, freq): + def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray: obj_type = type(self._obj) result = _round_field(self._obj.data, name, freq) return obj_type(result, name=name, coords=self._obj.coords, dims=self._obj.dims) - def floor(self, freq): + def floor(self, freq: str) -> T_DataArray: """ Round timestamps downward to specified frequency resolution. 
@@ -233,7 +235,7 @@ def floor(self, freq): return self._tslib_round_accessor("floor", freq) - def ceil(self, freq): + def ceil(self, freq: str) -> T_DataArray: """ Round timestamps upward to specified frequency resolution. @@ -249,7 +251,7 @@ def ceil(self, freq): """ return self._tslib_round_accessor("ceil", freq) - def round(self, freq): + def round(self, freq: str) -> T_DataArray: """ Round timestamps to specified frequency resolution. @@ -266,7 +268,7 @@ def round(self, freq): return self._tslib_round_accessor("round", freq) -class DatetimeAccessor(Properties): +class DatetimeAccessor(TimeAccessor[T_DataArray]): """Access datetime fields for DataArrays with datetime-like dtypes. Fields can be accessed through the `.dt` attribute @@ -301,7 +303,7 @@ class DatetimeAccessor(Properties): """ - def strftime(self, date_format): + def strftime(self, date_format: str) -> T_DataArray: """ Return an array of formatted strings specified by date_format, which supports the same string format as the python standard library. Details @@ -334,7 +336,7 @@ def strftime(self, date_format): result, name="strftime", coords=self._obj.coords, dims=self._obj.dims ) - def isocalendar(self): + def isocalendar(self) -> Dataset: """Dataset containing ISO year, week number, and weekday. Notes @@ -358,31 +360,48 @@ def isocalendar(self): return Dataset(data_vars) - year = Properties._tslib_field_accessor( - "year", "The year of the datetime", np.int64 - ) - month = Properties._tslib_field_accessor( - "month", "The month as January=1, December=12", np.int64 - ) - day = Properties._tslib_field_accessor("day", "The days of the datetime", np.int64) - hour = Properties._tslib_field_accessor( - "hour", "The hours of the datetime", np.int64 - ) - minute = Properties._tslib_field_accessor( - "minute", "The minutes of the datetime", np.int64 - ) - second = Properties._tslib_field_accessor( - "second", "The seconds of the datetime", np.int64 - ) - microsecond = Properties._tslib_field_accessor( - "microsecond", "The microseconds of the datetime", np.int64 - ) - nanosecond = Properties._tslib_field_accessor( - "nanosecond", "The nanoseconds of the datetime", np.int64 - ) + @property + def year(self) -> T_DataArray: + """The year of the datetime""" + return self._date_field("year", np.int64) + + @property + def month(self) -> T_DataArray: + """The month as January=1, December=12""" + return self._date_field("month", np.int64) + + @property + def day(self) -> T_DataArray: + """The days of the datetime""" + return self._date_field("day", np.int64) + + @property + def hour(self) -> T_DataArray: + """The hours of the datetime""" + return self._date_field("hour", np.int64) @property - def weekofyear(self): + def minute(self) -> T_DataArray: + """The minutes of the datetime""" + return self._date_field("minute", np.int64) + + @property + def second(self) -> T_DataArray: + """The seconds of the datetime""" + return self._date_field("second", np.int64) + + @property + def microsecond(self) -> T_DataArray: + """The microseconds of the datetime""" + return self._date_field("microsecond", np.int64) + + @property + def nanosecond(self) -> T_DataArray: + """The nanoseconds of the datetime""" + return self._date_field("nanosecond", np.int64) + + @property + def weekofyear(self) -> DataArray: "The week ordinal of the year" warnings.warn( @@ -396,64 +415,88 @@ def weekofyear(self): return weekofyear week = weekofyear - dayofweek = Properties._tslib_field_accessor( - "dayofweek", "The day of the week with Monday=0, Sunday=6", np.int64 - ) + + 
@property + def dayofweek(self) -> T_DataArray: + """The day of the week with Monday=0, Sunday=6""" + return self._date_field("dayofweek", np.int64) + weekday = dayofweek - weekday_name = Properties._tslib_field_accessor( - "weekday_name", "The name of day in a week", object - ) - - dayofyear = Properties._tslib_field_accessor( - "dayofyear", "The ordinal day of the year", np.int64 - ) - quarter = Properties._tslib_field_accessor("quarter", "The quarter of the date") - days_in_month = Properties._tslib_field_accessor( - "days_in_month", "The number of days in the month", np.int64 - ) + @property + def weekday_name(self) -> T_DataArray: + """The name of day in a week""" + return self._date_field("weekday_name", object) + + @property + def dayofyear(self) -> T_DataArray: + """The ordinal day of the year""" + return self._date_field("dayofyear", np.int64) + + @property + def quarter(self) -> T_DataArray: + """The quarter of the date""" + return self._date_field("quarter", np.int64) + + @property + def days_in_month(self) -> T_DataArray: + """The number of days in the month""" + return self._date_field("days_in_month", np.int64) + daysinmonth = days_in_month - season = Properties._tslib_field_accessor("season", "Season of the year", object) - - time = Properties._tslib_field_accessor( - "time", "Timestamps corresponding to datetimes", object - ) - - date = Properties._tslib_field_accessor( - "date", "Date corresponding to datetimes", object - ) - - is_month_start = Properties._tslib_field_accessor( - "is_month_start", - "Indicates whether the date is the first day of the month.", - bool, - ) - is_month_end = Properties._tslib_field_accessor( - "is_month_end", "Indicates whether the date is the last day of the month.", bool - ) - is_quarter_start = Properties._tslib_field_accessor( - "is_quarter_start", - "Indicator for whether the date is the first day of a quarter.", - bool, - ) - is_quarter_end = Properties._tslib_field_accessor( - "is_quarter_end", - "Indicator for whether the date is the last day of a quarter.", - bool, - ) - is_year_start = Properties._tslib_field_accessor( - "is_year_start", "Indicate whether the date is the first day of a year.", bool - ) - is_year_end = Properties._tslib_field_accessor( - "is_year_end", "Indicate whether the date is the last day of the year.", bool - ) - is_leap_year = Properties._tslib_field_accessor( - "is_leap_year", "Boolean indicator if the date belongs to a leap year.", bool - ) + @property + def season(self) -> T_DataArray: + """Season of the year""" + return self._date_field("season", object) + + @property + def time(self) -> T_DataArray: + """Timestamps corresponding to datetimes""" + return self._date_field("time", object) + + @property + def date(self) -> T_DataArray: + """Date corresponding to datetimes""" + return self._date_field("date", object) + + @property + def is_month_start(self) -> T_DataArray: + """Indicate whether the date is the first day of the month""" + return self._date_field("is_month_start", bool) + + @property + def is_month_end(self) -> T_DataArray: + """Indicate whether the date is the last day of the month""" + return self._date_field("is_month_end", bool) + + @property + def is_quarter_start(self) -> T_DataArray: + """Indicate whether the date is the first day of a quarter""" + return self._date_field("is_quarter_start", bool) + + @property + def is_quarter_end(self) -> T_DataArray: + """Indicate whether the date is the last day of a quarter""" + return self._date_field("is_quarter_end", bool) @property - def 
calendar(self): + def is_year_start(self) -> T_DataArray: + """Indicate whether the date is the first day of a year""" + return self._date_field("is_year_start", bool) + + @property + def is_year_end(self) -> T_DataArray: + """Indicate whether the date is the last day of the year""" + return self._date_field("is_year_end", bool) + + @property + def is_leap_year(self) -> T_DataArray: + """Indicate if the date belongs to a leap year""" + return self._date_field("is_leap_year", bool) + + @property + def calendar(self) -> CFCalendar: """The name of the calendar of the dates. Only relevant for arrays of :py:class:`cftime.datetime` objects, @@ -462,7 +505,7 @@ def calendar(self): return infer_calendar_name(self._obj.data) -class TimedeltaAccessor(Properties): +class TimedeltaAccessor(TimeAccessor[T_DataArray]): """Access Timedelta fields for DataArrays with Timedelta-like dtypes. Fields can be accessed through the `.dt` attribute for applicable DataArrays. @@ -502,28 +545,31 @@ class TimedeltaAccessor(Properties): * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 """ - days = Properties._tslib_field_accessor( - "days", "Number of days for each element.", np.int64 - ) - seconds = Properties._tslib_field_accessor( - "seconds", - "Number of seconds (>= 0 and less than 1 day) for each element.", - np.int64, - ) - microseconds = Properties._tslib_field_accessor( - "microseconds", - "Number of microseconds (>= 0 and less than 1 second) for each element.", - np.int64, - ) - nanoseconds = Properties._tslib_field_accessor( - "nanoseconds", - "Number of nanoseconds (>= 0 and less than 1 microsecond) for each element.", - np.int64, - ) - - -class CombinedDatetimelikeAccessor(DatetimeAccessor, TimedeltaAccessor): - def __new__(cls, obj): + @property + def days(self) -> T_DataArray: + """Number of days for each element""" + return self._date_field("days", np.int64) + + @property + def seconds(self) -> T_DataArray: + """Number of seconds (>= 0 and less than 1 day) for each element""" + return self._date_field("seconds", np.int64) + + @property + def microseconds(self) -> T_DataArray: + """Number of microseconds (>= 0 and less than 1 second) for each element""" + return self._date_field("microseconds", np.int64) + + @property + def nanoseconds(self) -> T_DataArray: + """Number of nanoseconds (>= 0 and less than 1 microsecond) for each element""" + return self._date_field("nanoseconds", np.int64) + + +class CombinedDatetimelikeAccessor( + DatetimeAccessor[T_DataArray], TimedeltaAccessor[T_DataArray] +): + def __new__(cls, obj: T_DataArray) -> CombinedDatetimelikeAccessor: # CombinedDatetimelikeAccessor isn't really instatiated. Instead # we need to choose which parent (datetime or timedelta) is # appropriate. Since we're checking the dtypes anyway, we'll just @@ -537,6 +583,6 @@ def __new__(cls, obj): ) if is_np_timedelta_like(obj.dtype): - return TimedeltaAccessor(obj) + return TimedeltaAccessor(obj) # type: ignore[return-value] else: - return DatetimeAccessor(obj) + return DatetimeAccessor(obj) # type: ignore[return-value] diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 54c9b857a7a..7f65b3add9b 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -37,27 +37,24 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from __future__ import annotations + import codecs import re import textwrap from functools import reduce from operator import or_ as set_union -from typing import ( - Any, - Callable, - Hashable, - Mapping, - Optional, - Pattern, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Mapping, Pattern from unicodedata import normalize import numpy as np from .computation import apply_ufunc +from .npcompat import DTypeLike +from .types import T_DataArray + +if TYPE_CHECKING: + from .dataarray import DataArray _cpython_optimized_encoders = ( "utf-8", @@ -112,10 +109,10 @@ def _apply_str_ufunc( *, func: Callable, obj: Any, - dtype: Union[str, np.dtype, Type] = None, - output_core_dims: Union[list, tuple] = ((),), + dtype: DTypeLike = None, + output_core_dims: list | tuple = ((),), output_sizes: Mapping[Any, int] = None, - func_args: Tuple = (), + func_args: tuple = (), func_kwargs: Mapping = {}, ) -> Any: # TODO handling of na values ? @@ -139,7 +136,7 @@ def _apply_str_ufunc( ) -class StringAccessor: +class StringAccessor(Generic[T_DataArray]): r"""Vectorized string functions for string-like arrays. Similar to pandas, fields can be accessed through the `.str` attribute @@ -204,13 +201,10 @@ class StringAccessor: __slots__ = ("_obj",) - def __init__(self, obj): + def __init__(self, obj: T_DataArray) -> None: self._obj = obj - def _stringify( - self, - invar: Any, - ) -> Union[str, bytes, Any]: + def _stringify(self, invar: Any) -> str | bytes | Any: """ Convert a string-like to the correct string/bytes type. @@ -225,10 +219,10 @@ def _apply( self, *, func: Callable, - dtype: Union[str, np.dtype, Type] = None, - output_core_dims: Union[list, tuple] = ((),), + dtype: DTypeLike = None, + output_core_dims: list | tuple = ((),), output_sizes: Mapping[Any, int] = None, - func_args: Tuple = (), + func_args: tuple = (), func_kwargs: Mapping = {}, ) -> Any: return _apply_str_ufunc( @@ -244,10 +238,10 @@ def _apply( def _re_compile( self, *, - pat: Union[str, bytes, Pattern, Any], + pat: str | bytes | Pattern | Any, flags: int = 0, case: bool = None, - ) -> Union[Pattern, Any]: + ) -> Pattern | Any: is_compiled_re = isinstance(pat, re.Pattern) if is_compiled_re and flags != 0: @@ -281,7 +275,7 @@ def func(x): else: return _apply_str_ufunc(func=func, obj=pat, dtype=np.object_) - def len(self) -> Any: + def len(self) -> T_DataArray: """ Compute the length of each string in the array. @@ -293,22 +287,19 @@ def len(self) -> Any: def __getitem__( self, - key: Union[int, slice], + key: int | slice, ) -> Any: if isinstance(key, slice): return self.slice(start=key.start, stop=key.stop, step=key.step) else: return self.get(key) - def __add__( - self, - other: Any, - ) -> Any: + def __add__(self, other: Any) -> T_DataArray: return self.cat(other, sep="") def __mul__( self, - num: Union[int, Any], + num: int | Any, ) -> Any: return self.repeat(num) @@ -327,8 +318,8 @@ def __mod__( def get( self, - i: Union[int, Any], - default: Union[str, bytes] = "", + i: int | Any, + default: str | bytes = "", ) -> Any: """ Extract character number `i` from each string in the array. @@ -360,9 +351,9 @@ def f(x, iind): def slice( self, - start: Union[int, Any] = None, - stop: Union[int, Any] = None, - step: Union[int, Any] = None, + start: int | Any = None, + stop: int | Any = None, + step: int | Any = None, ) -> Any: """ Slice substrings from each string in the array. 
@@ -391,9 +382,9 @@ def slice( def slice_replace( self, - start: Union[int, Any] = None, - stop: Union[int, Any] = None, - repl: Union[str, bytes, Any] = "", + start: int | Any = None, + stop: int | Any = None, + repl: str | bytes | Any = "", ) -> Any: """ Replace a positional slice of a string with another value. @@ -436,11 +427,7 @@ def func(x, istart, istop, irepl): return self._apply(func=func, func_args=(start, stop, repl)) - def cat( - self, - *others, - sep: Union[str, bytes, Any] = "", - ) -> Any: + def cat(self, *others, sep: str | bytes | Any = "") -> T_DataArray: """ Concatenate strings elementwise in the DataArray with other strings. @@ -524,8 +511,8 @@ def cat( def join( self, dim: Hashable = None, - sep: Union[str, bytes, Any] = "", - ) -> Any: + sep: str | bytes | Any = "", + ) -> T_DataArray: """ Concatenate strings in a DataArray along a particular dimension. @@ -596,7 +583,7 @@ def format( self, *args: Any, **kwargs: Any, - ) -> Any: + ) -> T_DataArray: """ Perform python string formatting on each element of the DataArray. @@ -676,7 +663,7 @@ def format( ) return self._apply(func=func, func_args=args, func_kwargs={"kwargs": kwargs}) - def capitalize(self) -> Any: + def capitalize(self) -> T_DataArray: """ Convert strings in the array to be capitalized. @@ -686,7 +673,7 @@ def capitalize(self) -> Any: """ return self._apply(func=lambda x: x.capitalize()) - def lower(self) -> Any: + def lower(self) -> T_DataArray: """ Convert strings in the array to lowercase. @@ -696,7 +683,7 @@ def lower(self) -> Any: """ return self._apply(func=lambda x: x.lower()) - def swapcase(self) -> Any: + def swapcase(self) -> T_DataArray: """ Convert strings in the array to be swapcased. @@ -706,7 +693,7 @@ def swapcase(self) -> Any: """ return self._apply(func=lambda x: x.swapcase()) - def title(self) -> Any: + def title(self) -> T_DataArray: """ Convert strings in the array to titlecase. @@ -716,7 +703,7 @@ def title(self) -> Any: """ return self._apply(func=lambda x: x.title()) - def upper(self) -> Any: + def upper(self) -> T_DataArray: """ Convert strings in the array to uppercase. @@ -726,7 +713,7 @@ def upper(self) -> Any: """ return self._apply(func=lambda x: x.upper()) - def casefold(self) -> Any: + def casefold(self) -> T_DataArray: """ Convert strings in the array to be casefolded. @@ -744,7 +731,7 @@ def casefold(self) -> Any: def normalize( self, form: str, - ) -> Any: + ) -> T_DataArray: """ Return the Unicode normal form for the strings in the datarray. @@ -763,7 +750,7 @@ def normalize( """ return self._apply(func=lambda x: normalize(form, x)) - def isalnum(self) -> Any: + def isalnum(self) -> T_DataArray: """ Check whether all characters in each string are alphanumeric. @@ -774,7 +761,7 @@ def isalnum(self) -> Any: """ return self._apply(func=lambda x: x.isalnum(), dtype=bool) - def isalpha(self) -> Any: + def isalpha(self) -> T_DataArray: """ Check whether all characters in each string are alphabetic. @@ -785,7 +772,7 @@ def isalpha(self) -> Any: """ return self._apply(func=lambda x: x.isalpha(), dtype=bool) - def isdecimal(self) -> Any: + def isdecimal(self) -> T_DataArray: """ Check whether all characters in each string are decimal. @@ -796,7 +783,7 @@ def isdecimal(self) -> Any: """ return self._apply(func=lambda x: x.isdecimal(), dtype=bool) - def isdigit(self) -> Any: + def isdigit(self) -> T_DataArray: """ Check whether all characters in each string are digits. 
@@ -807,7 +794,7 @@ def isdigit(self) -> Any: """ return self._apply(func=lambda x: x.isdigit(), dtype=bool) - def islower(self) -> Any: + def islower(self) -> T_DataArray: """ Check whether all characters in each string are lowercase. @@ -818,7 +805,7 @@ def islower(self) -> Any: """ return self._apply(func=lambda x: x.islower(), dtype=bool) - def isnumeric(self) -> Any: + def isnumeric(self) -> T_DataArray: """ Check whether all characters in each string are numeric. @@ -829,7 +816,7 @@ def isnumeric(self) -> Any: """ return self._apply(func=lambda x: x.isnumeric(), dtype=bool) - def isspace(self) -> Any: + def isspace(self) -> T_DataArray: """ Check whether all characters in each string are spaces. @@ -840,7 +827,7 @@ def isspace(self) -> Any: """ return self._apply(func=lambda x: x.isspace(), dtype=bool) - def istitle(self) -> Any: + def istitle(self) -> T_DataArray: """ Check whether all characters in each string are titlecase. @@ -851,7 +838,7 @@ def istitle(self) -> Any: """ return self._apply(func=lambda x: x.istitle(), dtype=bool) - def isupper(self) -> Any: + def isupper(self) -> T_DataArray: """ Check whether all characters in each string are uppercase. @@ -863,11 +850,8 @@ def isupper(self) -> Any: return self._apply(func=lambda x: x.isupper(), dtype=bool) def count( - self, - pat: Union[str, bytes, Pattern, Any], - flags: int = 0, - case: bool = None, - ) -> Any: + self, pat: str | bytes | Pattern | Any, flags: int = 0, case: bool = None + ) -> T_DataArray: """ Count occurrences of pattern in each string of the array. @@ -903,10 +887,7 @@ def count( func = lambda x, ipat: len(ipat.findall(x)) return self._apply(func=func, func_args=(pat,), dtype=int) - def startswith( - self, - pat: Union[str, bytes, Any], - ) -> Any: + def startswith(self, pat: str | bytes | Any) -> T_DataArray: """ Test if the start of each string in the array matches a pattern. @@ -929,10 +910,7 @@ def startswith( func = lambda x, y: x.startswith(y) return self._apply(func=func, func_args=(pat,), dtype=bool) - def endswith( - self, - pat: Union[str, bytes, Any], - ) -> Any: + def endswith(self, pat: str | bytes | Any) -> T_DataArray: """ Test if the end of each string in the array matches a pattern. @@ -957,10 +935,10 @@ def endswith( def pad( self, - width: Union[int, Any], + width: int | Any, side: str = "left", - fillchar: Union[str, bytes, Any] = " ", - ) -> Any: + fillchar: str | bytes | Any = " ", + ) -> T_DataArray: """ Pad strings in the array up to width. @@ -999,9 +977,9 @@ def _padder( self, *, func: Callable, - width: Union[int, Any], - fillchar: Union[str, bytes, Any] = " ", - ) -> Any: + width: int | Any, + fillchar: str | bytes | Any = " ", + ) -> T_DataArray: """ Wrapper function to handle padding operations """ @@ -1015,10 +993,8 @@ def overfunc(x, iwidth, ifillchar): return self._apply(func=overfunc, func_args=(width, fillchar)) def center( - self, - width: Union[int, Any], - fillchar: Union[str, bytes, Any] = " ", - ) -> Any: + self, width: int | Any, fillchar: str | bytes | Any = " " + ) -> T_DataArray: """ Pad left and right side of each string in the array. @@ -1043,9 +1019,9 @@ def center( def ljust( self, - width: Union[int, Any], - fillchar: Union[str, bytes, Any] = " ", - ) -> Any: + width: int | Any, + fillchar: str | bytes | Any = " ", + ) -> T_DataArray: """ Pad right side of each string in the array. 
@@ -1070,9 +1046,9 @@ def ljust( def rjust( self, - width: Union[int, Any], - fillchar: Union[str, bytes, Any] = " ", - ) -> Any: + width: int | Any, + fillchar: str | bytes | Any = " ", + ) -> T_DataArray: """ Pad left side of each string in the array. @@ -1095,7 +1071,7 @@ def rjust( func = self._obj.dtype.type.rjust return self._padder(func=func, width=width, fillchar=fillchar) - def zfill(self, width: Union[int, Any]) -> Any: + def zfill(self, width: int | Any) -> T_DataArray: """ Pad each string in the array by prepending '0' characters. @@ -1120,11 +1096,11 @@ def zfill(self, width: Union[int, Any]) -> Any: def contains( self, - pat: Union[str, bytes, Pattern, Any], + pat: str | bytes | Pattern | Any, case: bool = None, flags: int = 0, regex: bool = True, - ) -> Any: + ) -> T_DataArray: """ Test if pattern or regex is contained within each string of the array. @@ -1181,22 +1157,22 @@ def func(x, ipat): if case or case is None: func = lambda x, ipat: ipat in x elif self._obj.dtype.char == "U": - uppered = self._obj.str.casefold() - uppat = StringAccessor(pat).casefold() - return uppered.str.contains(uppat, regex=False) + uppered = self.casefold() + uppat = StringAccessor(pat).casefold() # type: ignore[type-var] # hack? + return uppered.str.contains(uppat, regex=False) # type: ignore[return-value] else: - uppered = self._obj.str.upper() - uppat = StringAccessor(pat).upper() - return uppered.str.contains(uppat, regex=False) + uppered = self.upper() + uppat = StringAccessor(pat).upper() # type: ignore[type-var] # hack? + return uppered.str.contains(uppat, regex=False) # type: ignore[return-value] return self._apply(func=func, func_args=(pat,), dtype=bool) def match( self, - pat: Union[str, bytes, Pattern, Any], + pat: str | bytes | Pattern | Any, case: bool = None, flags: int = 0, - ) -> Any: + ) -> T_DataArray: """ Determine if each string in the array matches a regular expression. @@ -1229,10 +1205,8 @@ def match( return self._apply(func=func, func_args=(pat,), dtype=bool) def strip( - self, - to_strip: Union[str, bytes, Any] = None, - side: str = "both", - ) -> Any: + self, to_strip: str | bytes | Any = None, side: str = "both" + ) -> T_DataArray: """ Remove leading and trailing characters. @@ -1269,10 +1243,7 @@ def strip( return self._apply(func=func, func_args=(to_strip,)) - def lstrip( - self, - to_strip: Union[str, bytes, Any] = None, - ) -> Any: + def lstrip(self, to_strip: str | bytes | Any = None) -> T_DataArray: """ Remove leading characters. @@ -1295,10 +1266,7 @@ def lstrip( """ return self.strip(to_strip, side="left") - def rstrip( - self, - to_strip: Union[str, bytes, Any] = None, - ) -> Any: + def rstrip(self, to_strip: str | bytes | Any = None) -> T_DataArray: """ Remove trailing characters. @@ -1321,11 +1289,7 @@ def rstrip( """ return self.strip(to_strip, side="right") - def wrap( - self, - width: Union[int, Any], - **kwargs, - ) -> Any: + def wrap(self, width: int | Any, **kwargs) -> T_DataArray: """ Wrap long strings in the array in paragraphs with length less than `width`. @@ -1348,14 +1312,11 @@ def wrap( wrapped : same type as values """ ifunc = lambda x: textwrap.TextWrapper(width=x, **kwargs) - tw = StringAccessor(width)._apply(func=ifunc, dtype=np.object_) + tw = StringAccessor(width)._apply(func=ifunc, dtype=np.object_) # type: ignore[type-var] # hack? 
func = lambda x, itw: "\n".join(itw.wrap(x)) return self._apply(func=func, func_args=(tw,)) - def translate( - self, - table: Mapping[Union[str, bytes], Union[str, bytes]], - ) -> Any: + def translate(self, table: Mapping[str | bytes, str | bytes]) -> T_DataArray: """ Map characters of each string through the given mapping table. @@ -1376,8 +1337,8 @@ def translate( def repeat( self, - repeats: Union[int, Any], - ) -> Any: + repeats: int | Any, + ) -> T_DataArray: """ Repeat each string in the array. @@ -1400,11 +1361,11 @@ def repeat( def find( self, - sub: Union[str, bytes, Any], - start: Union[int, Any] = 0, - end: Union[int, Any] = None, + sub: str | bytes | Any, + start: int | Any = 0, + end: int | Any = None, side: str = "left", - ) -> Any: + ) -> T_DataArray: """ Return lowest or highest indexes in each strings in the array where the substring is fully contained between [start:end]. @@ -1445,10 +1406,10 @@ def find( def rfind( self, - sub: Union[str, bytes, Any], - start: Union[int, Any] = 0, - end: Union[int, Any] = None, - ) -> Any: + sub: str | bytes | Any, + start: int | Any = 0, + end: int | Any = None, + ) -> T_DataArray: """ Return highest indexes in each strings in the array where the substring is fully contained between [start:end]. @@ -1477,11 +1438,11 @@ def rfind( def index( self, - sub: Union[str, bytes, Any], - start: Union[int, Any] = 0, - end: Union[int, Any] = None, + sub: str | bytes | Any, + start: int | Any = 0, + end: int | Any = None, side: str = "left", - ) -> Any: + ) -> T_DataArray: """ Return lowest or highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as @@ -1528,10 +1489,10 @@ def index( def rindex( self, - sub: Union[str, bytes, Any], - start: Union[int, Any] = 0, - end: Union[int, Any] = None, - ) -> Any: + sub: str | bytes | Any, + start: int | Any = 0, + end: int | Any = None, + ) -> T_DataArray: """ Return highest indexes in each strings where the substring is fully contained between [start:end]. This is the same as @@ -1566,13 +1527,13 @@ def rindex( def replace( self, - pat: Union[str, bytes, Pattern, Any], - repl: Union[str, bytes, Callable, Any], - n: Union[int, Any] = -1, + pat: str | bytes | Pattern | Any, + repl: str | bytes | Callable | Any, + n: int | Any = -1, case: bool = None, flags: int = 0, regex: bool = True, - ) -> Any: + ) -> T_DataArray: """ Replace occurrences of pattern/regex in the array with some string. @@ -1639,7 +1600,7 @@ def replace( def extract( self, - pat: Union[str, bytes, Pattern, Any], + pat: str | bytes | Pattern | Any, dim: Hashable, case: bool = None, flags: int = 0, @@ -1783,7 +1744,7 @@ def _get_res_multi(val, pat): def extractall( self, - pat: Union[str, bytes, Pattern, Any], + pat: str | bytes | Pattern | Any, group_dim: Hashable, match_dim: Hashable, case: bool = None, @@ -1958,7 +1919,7 @@ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype): def findall( self, - pat: Union[str, bytes, Pattern, Any], + pat: str | bytes | Pattern | Any, case: bool = None, flags: int = 0, ) -> Any: @@ -2053,9 +2014,9 @@ def _partitioner( self, *, func: Callable, - dim: Hashable, - sep: Optional[Union[str, bytes, Any]], - ) -> Any: + dim: Hashable | None, + sep: str | bytes | Any | None, + ) -> T_DataArray: """ Implements logic for `partition` and `rpartition`. 
""" @@ -2067,7 +2028,7 @@ def _partitioner( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, axis=-1) + return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value] arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) @@ -2083,9 +2044,9 @@ def _partitioner( def partition( self, - dim: Optional[Hashable], - sep: Union[str, bytes, Any] = " ", - ) -> Any: + dim: Hashable | None, + sep: str | bytes | Any = " ", + ) -> T_DataArray: """ Split the strings in the DataArray at the first occurrence of separator `sep`. @@ -2103,7 +2064,7 @@ def partition( dim : hashable or None Name for the dimension to place the 3 elements in. If `None`, place the results as list elements in an object DataArray. - sep : str, default: " " + sep : str or bytes or array-like, default: " " String to split on. If array-like, it is broadcast. @@ -2121,9 +2082,9 @@ def partition( def rpartition( self, - dim: Optional[Hashable], - sep: Union[str, bytes, Any] = " ", - ) -> Any: + dim: Hashable | None, + sep: str | bytes | Any = " ", + ) -> T_DataArray: """ Split the strings in the DataArray at the last occurrence of separator `sep`. @@ -2141,7 +2102,7 @@ def rpartition( dim : hashable or None Name for the dimension to place the 3 elements in. If `None`, place the results as list elements in an object DataArray. - sep : str, default: " " + sep : str or bytes or array-like, default: " " String to split on. If array-like, it is broadcast. @@ -2163,9 +2124,9 @@ def _splitter( func: Callable, pre: bool, dim: Hashable, - sep: Optional[Union[str, bytes, Any]], + sep: str | bytes | Any | None, maxsplit: int, - ) -> Any: + ) -> DataArray: """ Implements logic for `split` and `rsplit`. """ @@ -2208,10 +2169,10 @@ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype): def split( self, - dim: Optional[Hashable], - sep: Union[str, bytes, Any] = None, + dim: Hashable | None, + sep: str | bytes | Any = None, maxsplit: int = -1, - ) -> Any: + ) -> DataArray: r""" Split strings in a DataArray around the given separator/delimiter `sep`. @@ -2324,10 +2285,10 @@ def split( def rsplit( self, - dim: Optional[Hashable], - sep: Union[str, bytes, Any] = None, - maxsplit: Union[int, Any] = -1, - ) -> Any: + dim: Hashable | None, + sep: str | bytes | Any = None, + maxsplit: int | Any = -1, + ) -> DataArray: r""" Split strings in a DataArray around the given separator/delimiter `sep`. @@ -2443,8 +2404,8 @@ def rsplit( def get_dummies( self, dim: Hashable, - sep: Union[str, bytes, Any] = "|", - ) -> Any: + sep: str | bytes | Any = "|", + ) -> DataArray: """ Return DataArray of dummy/indicator variables. @@ -2519,11 +2480,7 @@ def get_dummies( res.coords[dim] = vals return res - def decode( - self, - encoding: str, - errors: str = "strict", - ) -> Any: + def decode(self, encoding: str, errors: str = "strict") -> T_DataArray: """ Decode character string in the array using indicated encoding. @@ -2533,7 +2490,7 @@ def decode( The encoding to use. Please see the Python documentation `codecs standard encoders `_ section for a list of encodings handlers. - errors : str, optional + errors : str, default: "strict" The handler for encoding errors. Please see the Python documentation `codecs error handlers `_ for a list of error handlers. 
@@ -2549,11 +2506,7 @@ def decode( func = lambda x: decoder(x, errors)[0] return self._apply(func=func, dtype=np.str_) - def encode( - self, - encoding: str, - errors: str = "strict", - ) -> Any: + def encode(self, encoding: str, errors: str = "strict") -> T_DataArray: """ Encode character string in the array using indicated encoding. @@ -2563,7 +2516,7 @@ def encode( The encoding to use. Please see the Python documentation `codecs standard encoders `_ section for a list of encodings handlers. - errors : str, optional + errors : str, default: "strict" The handler for encoding errors. Please see the Python documentation `codecs error handlers `_ for a list of error handlers. diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index c1a9192233e..7b206fceeeb 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -765,7 +765,7 @@ def align( def deep_align( - objects, + objects: Iterable[Any], join: JoinOptions = "inner", copy=True, indexes=None, diff --git a/xarray/core/computation.py b/xarray/core/computation.py index da2db39525a..1ca63ff369e 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -282,7 +282,7 @@ def build_output_coords_and_indexes( def apply_dataarray_vfunc( func, *args, - signature, + signature: _UFuncSignature, join: JoinOptions = "inner", exclude_dims=frozenset(), keep_attrs="override", @@ -405,12 +405,12 @@ def _unpack_dict_tuples( def apply_dict_of_variables_vfunc( - func, *args, signature, join="inner", fill_value=None + func, *args, signature: _UFuncSignature, join="inner", fill_value=None ): """Apply a variable level function over dicts of DataArray, DataArray, Variable and ndarray objects. """ - args = [_as_variables_or_variable(arg) for arg in args] + args = tuple(_as_variables_or_variable(arg) for arg in args) names = join_dict_keys(args, how=join) grouped_by_name = collect_dict_values(args, names, fill_value) @@ -443,13 +443,13 @@ def _fast_dataset( def apply_dataset_vfunc( func, *args, - signature, + signature: _UFuncSignature, join="inner", dataset_join="exact", fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), keep_attrs="override", -): +) -> Dataset | tuple[Dataset, ...]: """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. """ @@ -472,12 +472,13 @@ def apply_dataset_vfunc( list_of_coords, list_of_indexes = build_output_coords_and_indexes( args, signature, exclude_dims, combine_attrs=keep_attrs ) - args = [getattr(arg, "data_vars", arg) for arg in args] + args = tuple(getattr(arg, "data_vars", arg) for arg in args) result_vars = apply_dict_of_variables_vfunc( func, *args, signature=signature, join=dataset_join, fill_value=fill_value ) + out: Dataset | tuple[Dataset, ...] 
if signature.num_outputs > 1: out = tuple( _fast_dataset(*args) @@ -657,14 +658,14 @@ def _vectorize(func, signature, output_dtypes, exclude_dims): def apply_variable_ufunc( func, *args, - signature, + signature: _UFuncSignature, exclude_dims=frozenset(), dask="forbidden", output_dtypes=None, vectorize=False, keep_attrs="override", dask_gufunc_kwargs=None, -): +) -> Variable | tuple[Variable, ...]: """Apply a ndarray level function over Variable and/or ndarray objects.""" from .variable import Variable, as_compatible_data @@ -785,7 +786,7 @@ def func(*arrays): combine_attrs=keep_attrs, ) - output = [] + output: list[Variable] = [] for dims, data in zip(output_dims, result_data): data = as_compatible_data(data) if data.ndim != len(dims): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 35c0aab3fb8..80a55ff7d90 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -356,7 +356,7 @@ class DataArray( _resample_cls = resample.DataArrayResample _weighted_cls = weighted.DataArrayWeighted - dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor) + dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"]) def __init__( self, @@ -5077,4 +5077,4 @@ def interp_calendar( # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names - str = utils.UncachedAccessor(StringAccessor) + str = utils.UncachedAccessor(StringAccessor["DataArray"]) diff --git a/xarray/core/types.py b/xarray/core/types.py index 5acf4f7b587..a2f268b983a 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -46,6 +46,18 @@ ] JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"] +CFCalendar = Literal[ + "standard", + "gregorian", + "proleptic_gregorian", + "noleap", + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", +] + # TODO: Wait until mypy supports recursive objects in combination with typevars _T = TypeVar("_T") NestedSequence = Union[ diff --git a/xarray/core/utils.py b/xarray/core/utils.py index a2f95123892..121710d19aa 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -16,6 +16,7 @@ Callable, Collection, Container, + Generic, Hashable, Iterable, Iterator, @@ -24,6 +25,7 @@ MutableSet, TypeVar, cast, + overload, ) import numpy as np @@ -896,7 +898,10 @@ def drop_missing_dims( ) -class UncachedAccessor: +_Accessor = TypeVar("_Accessor") + + +class UncachedAccessor(Generic[_Accessor]): """Acts like a property, but on both classes and class instances This class is necessary because some tools (e.g. pydoc and sphinx) @@ -904,14 +909,22 @@ class UncachedAccessor: accessor. """ - def __init__(self, accessor): + def __init__(self, accessor: type[_Accessor]) -> None: self._accessor = accessor - def __get__(self, obj, cls): + @overload + def __get__(self, obj: None, cls) -> type[_Accessor]: + ... + + @overload + def __get__(self, obj: object, cls) -> _Accessor: + ... + + def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor: if obj is None: return self._accessor - return self._accessor(obj) + return self._accessor(obj) # type: ignore # assume it is a valid accessor! # Singleton type, as per https://github.com/python/typing/pull/240
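
Illustrative sketch (not part of the patch): the typing pattern above combines a generic descriptor (`UncachedAccessor`) with accessor classes parametrized by the wrapped object type, so that expressions like `da.dt.year` or `da.str.upper()` type-check as `DataArray` rather than `Any`. The minimal, runnable stand-in below uses hypothetical names (`Array`, `DemoAccessor`, `T_Obj`) in place of `DataArray`, `TimeAccessor`/`StringAccessor`, and `T_DataArray`; only `UncachedAccessor` mirrors the class added in xarray/core/utils.py.

    from __future__ import annotations

    from typing import Generic, TypeVar, overload

    _Accessor = TypeVar("_Accessor")
    T_Obj = TypeVar("T_Obj")


    class UncachedAccessor(Generic[_Accessor]):
        """Descriptor with the same shape as xarray.core.utils.UncachedAccessor above."""

        def __init__(self, accessor: type[_Accessor]) -> None:
            self._accessor = accessor

        @overload
        def __get__(self, obj: None, cls) -> type[_Accessor]:
            ...

        @overload
        def __get__(self, obj: object, cls) -> _Accessor:
            ...

        def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor:
            if obj is None:
                # Accessed on the class: return the (possibly parametrized) accessor class,
                # which is what documentation tools such as pydoc/sphinx need to see.
                return self._accessor
            # Accessed on an instance: build an accessor bound to that instance.
            return self._accessor(obj)  # type: ignore  # assume the accessor accepts ``obj``, as in the patch


    class DemoAccessor(Generic[T_Obj]):
        """Stand-in for TimeAccessor/StringAccessor: methods return the wrapped type."""

        __slots__ = ("_obj",)

        def __init__(self, obj: T_Obj) -> None:
            self._obj = obj

        def identity(self) -> T_Obj:
            return self._obj


    class Array:
        """Hypothetical stand-in for DataArray."""

        demo = UncachedAccessor(DemoAccessor["Array"])


    arr = Array()
    assert isinstance(arr.demo, DemoAccessor)  # instance access builds a bound accessor
    assert arr.demo.identity() is arr          # T_Obj resolves to Array
    # Under mypy: reveal_type(arr.demo.identity()) -> Array,
    # and reveal_type(Array.demo) -> the accessor class, not an instance.

Usage note: this is the same mechanism behind `dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"])` and `str = utils.UncachedAccessor(StringAccessor["DataArray"])` in the dataarray.py hunks above; parametrizing the accessor at the class-attribute site is what lets the `T_DataArray`-returning properties propagate the concrete `DataArray` type to callers.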