diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index df901f9a1d9..2f29a32dbf5 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -11,8 +11,8 @@ from xarray.backends.locks import acquire from xarray.backends.lru_cache import LRUCache -from xarray.core import utils from xarray.core.options import OPTIONS +from xarray.namedarray import utils # Global cache for storing open files. FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache( diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bae779af652..1405ca1805a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -13,7 +13,7 @@ import numpy as np -from xarray.core import dtypes, duck_array_ops, utils +from xarray.core import dtypes, duck_array_ops from xarray.core.alignment import align, deep_align from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric @@ -26,6 +26,7 @@ from xarray.core.types import Dims, T_DataArray from xarray.core.utils import is_dict_like, is_scalar from xarray.core.variable import Variable +from xarray.namedarray import utils if TYPE_CHECKING: from xarray.core.coordinates import Coordinates diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 0c85b2a2d69..c4957a36351 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -26,11 +26,11 @@ from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray from xarray.core.utils import ( Frozen, - ReprObject, either_dict_or_kwargs, emit_user_level_warning, ) from xarray.core.variable import Variable, as_variable, calculate_dimensions +from xarray.namedarray.utils import ReprObject if TYPE_CHECKING: from xarray.core.common import DataWithCoords diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0b9786dc2b7..58101be6cae 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -53,7 +53,6 @@ from xarray.core.utils import ( 
Default, HybridMappingProxy, - ReprObject, _default, either_dict_or_kwargs, emit_user_level_warning, @@ -64,6 +63,7 @@ as_compatible_data, as_variable, ) +from xarray.namedarray.utils import ReprObject from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 0762fa03112..0c1dbf518c9 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,188 +1,3 @@ from __future__ import annotations -import functools - -import numpy as np - -from xarray.core import utils - -# Use as a sentinel value to indicate a dtype appropriate NA value. -NA = utils.ReprObject("") - - -@functools.total_ordering -class AlwaysGreaterThan: - def __gt__(self, other): - return True - - def __eq__(self, other): - return isinstance(other, type(self)) - - -@functools.total_ordering -class AlwaysLessThan: - def __lt__(self, other): - return True - - def __eq__(self, other): - return isinstance(other, type(self)) - - -# Equivalence to np.inf (-np.inf) for object-type -INF = AlwaysGreaterThan() -NINF = AlwaysLessThan() - - -# Pairs of types that, if both found, should be promoted to object dtype -# instead of following NumPy's own type-promotion rules. These type promotion -# rules match pandas instead. For reference, see the NumPy type hierarchy: -# https://numpy.org/doc/stable/reference/arrays.scalars.html -PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ( - (np.number, np.character), # numpy promotes to character - (np.bool_, np.character), # numpy promotes to character - (np.bytes_, np.str_), # numpy promotes to unicode -) - - -def maybe_promote(dtype): - """Simpler equivalent of pandas.core.common._maybe_promote - - Parameters - ---------- - dtype : np.dtype - - Returns - ------- - dtype : Promoted dtype that can hold missing values. - fill_value : Valid missing value for the promoted dtype. - """ - # N.B. 
these casting rules should match pandas - if np.issubdtype(dtype, np.floating): - fill_value = np.nan - elif np.issubdtype(dtype, np.timedelta64): - # See https://github.com/numpy/numpy/issues/10685 - # np.timedelta64 is a subclass of np.integer - # Check np.timedelta64 before np.integer - fill_value = np.timedelta64("NaT") - elif np.issubdtype(dtype, np.integer): - dtype = np.float32 if dtype.itemsize <= 2 else np.float64 - fill_value = np.nan - elif np.issubdtype(dtype, np.complexfloating): - fill_value = np.nan + np.nan * 1j - elif np.issubdtype(dtype, np.datetime64): - fill_value = np.datetime64("NaT") - else: - dtype = object - fill_value = np.nan - - dtype = np.dtype(dtype) - fill_value = dtype.type(fill_value) - return dtype, fill_value - - -NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} - - -def get_fill_value(dtype): - """Return an appropriate fill value for this dtype. - - Parameters - ---------- - dtype : np.dtype - - Returns - ------- - fill_value : Missing value corresponding to this dtype. - """ - _, fill_value = maybe_promote(dtype) - return fill_value - - -def get_pos_infinity(dtype, max_for_int=False): - """Return an appropriate positive infinity for this dtype. - - Parameters - ---------- - dtype : np.dtype - max_for_int : bool - Return np.iinfo(dtype).max instead of np.inf - - Returns - ------- - fill_value : positive infinity value corresponding to this dtype. - """ - if issubclass(dtype.type, np.floating): - return np.inf - - if issubclass(dtype.type, np.integer): - if max_for_int: - return np.iinfo(dtype).max - else: - return np.inf - - if issubclass(dtype.type, np.complexfloating): - return np.inf + 1j * np.inf - - return INF - - -def get_neg_infinity(dtype, min_for_int=False): - """Return an appropriate positive infinity for this dtype. 
- - Parameters - ---------- - dtype : np.dtype - min_for_int : bool - Return np.iinfo(dtype).min instead of -np.inf - - Returns - ------- - fill_value : positive infinity value corresponding to this dtype. - """ - if issubclass(dtype.type, np.floating): - return -np.inf - - if issubclass(dtype.type, np.integer): - if min_for_int: - return np.iinfo(dtype).min - else: - return -np.inf - - if issubclass(dtype.type, np.complexfloating): - return -np.inf - 1j * np.inf - - return NINF - - -def is_datetime_like(dtype): - """Check if a dtype is a subclass of the numpy datetime types""" - return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) - - -def result_type( - *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, -) -> np.dtype: - """Like np.result_type, but with type promotion rules matching pandas. - - Examples of changed behavior: - number + string -> object (not string) - bytes + unicode -> object (not unicode) - - Parameters - ---------- - *arrays_and_dtypes : list of arrays and dtypes - The dtype is extracted from both numpy and dask arrays. - - Returns - ------- - numpy.dtype for the result. 
- """ - types = {np.result_type(t).type for t in arrays_and_dtypes} - - for left, right in PROMOTE_TO_OBJECT: - if any(issubclass(t, left) for t in types) and any( - issubclass(t, right) for t in types - ): - return np.dtype(object) - - return np.result_type(*arrays_and_dtypes) +from xarray.namedarray.dtypes import * # noqa: F401, F403 diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 9af5d693170..9556bb61b67 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,103 +1 @@ -from __future__ import annotations - -from importlib import import_module -from types import ModuleType -from typing import TYPE_CHECKING, Any, Literal - -import numpy as np -from packaging.version import Version - -from xarray.core.utils import is_duck_array, is_scalar, module_available - -integer_types = (int, np.integer) - -if TYPE_CHECKING: - ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"] - DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic - - -class DuckArrayModule: - """ - Solely for internal isinstance and version checks. 
- - Motivated by having to only import pint when required (as pint currently imports xarray) - https://github.com/pydata/xarray/pull/5561#discussion_r664815718 - """ - - module: ModuleType | None - version: Version - type: DuckArrayTypes - available: bool - - def __init__(self, mod: ModType) -> None: - duck_array_module: ModuleType | None - duck_array_version: Version - duck_array_type: DuckArrayTypes - try: - duck_array_module = import_module(mod) - duck_array_version = Version(duck_array_module.__version__) - - if mod == "dask": - duck_array_type = (import_module("dask.array").Array,) - elif mod == "pint": - duck_array_type = (duck_array_module.Quantity,) - elif mod == "cupy": - duck_array_type = (duck_array_module.ndarray,) - elif mod == "sparse": - duck_array_type = (duck_array_module.SparseArray,) - elif mod == "cubed": - duck_array_type = (duck_array_module.Array,) - else: - raise NotImplementedError - - except (ImportError, AttributeError): # pragma: no cover - duck_array_module = None - duck_array_version = Version("0.0.0") - duck_array_type = () - - self.module = duck_array_module - self.version = duck_array_version - self.type = duck_array_type - self.available = duck_array_module is not None - - -_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {} - - -def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule: - if mod not in _cached_duck_array_modules: - duckmod = DuckArrayModule(mod) - _cached_duck_array_modules[mod] = duckmod - return duckmod - else: - return _cached_duck_array_modules[mod] - - -def array_type(mod: ModType) -> DuckArrayTypes: - """Quick wrapper to get the array class of the module.""" - return _get_cached_duck_array_module(mod).type - - -def mod_version(mod: ModType) -> Version: - """Quick wrapper to get the version of the module.""" - return _get_cached_duck_array_module(mod).version - - -def is_dask_collection(x): - if module_available("dask"): - from dask.base import is_dask_collection - - return 
is_dask_collection(x) - return False - - -def is_duck_dask_array(x): - return is_duck_array(x) and is_dask_collection(x) - - -def is_chunked_array(x) -> bool: - return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) - - -def is_0d_dask_array(x): - return is_duck_dask_array(x) and is_scalar(x) +from xarray.namedarray.pycompat import * # noqa: F401, F403 diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ad86b2c7fec..065bac4d5b9 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -605,31 +605,6 @@ def __repr__(self: Any) -> str: return f"{type(self).__name__}(array={self.array!r})" -class ReprObject: - """Object that prints as the given value, for use with sentinel values.""" - - __slots__ = ("_value",) - - def __init__(self, value: str): - self._value = value - - def __repr__(self) -> str: - return self._value - - def __eq__(self, other) -> bool: - if isinstance(other, ReprObject): - return self._value == other._value - return False - - def __hash__(self) -> int: - return hash((type(self), self._value)) - - def __dask_tokenize__(self): - from dask.base import normalize_token - - return normalize_token((type(self), self._value)) - - @contextlib.contextmanager def close_on_error(f): """Context manager to ensure that a file opened by xarray is closed if an diff --git a/xarray/namedarray/arithmetic.py b/xarray/namedarray/arithmetic.py new file mode 100644 index 00000000000..c257998e265 --- /dev/null +++ b/xarray/namedarray/arithmetic.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import textwrap +import typing + + +class ImplementsArrayReduce: + __slots__ = () + + @classmethod + def _reduce_method( + cls, func: typing.Callable, include_skipna: bool, numeric_only: bool + ): + if include_skipna: + + def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs): + return self.reduce(func, dim, axis, skipna=skipna, **kwargs) + + else: + + def wrapped_func(self, dim=None, axis=None, **kwargs): # type: ignore[misc] 
+ return self.reduce(func, dim, axis, **kwargs) + + return wrapped_func + + _reduce_extra_args_docstring = textwrap.dedent( + """\ + dim : str or sequence of str, optional + Dimension(s) over which to apply `{name}`. + axis : int or sequence of int, optional + Axis(es) over which to apply `{name}`. Only one of the 'dim' + and 'axis' arguments can be supplied. If neither are supplied, then + `{name}` is calculated over axes.""" + ) + + _cum_extra_args_docstring = textwrap.dedent( + """\ + dim : str or sequence of str, optional + Dimension over which to apply `{name}`. + axis : int or sequence of int, optional + Axis over which to apply `{name}`. Only one of the 'dim' + and 'axis' arguments can be supplied.""" + ) + + +class IncludeReduceMethods: + __slots__ = () + + +class IncludeCumMethods: + ... + + +class IncludeNumpySameMethods: + ... + + +class SupportsArithmetic: + ... + + +class NamedArrayOpsMixin: + ... + + +class NamedArrayArithmetic: + __slots__ = () + # prioritize our operations over those of numpy.ndarray (priority=0) + __array_priority__ = 50 diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 16a7b422f1b..49fa393e5b1 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -4,27 +4,22 @@ import math import sys import typing -from collections.abc import Hashable, Iterable, Mapping +from collections.abc import Hashable, Mapping import numpy as np # TODO: get rid of this after migrating this class to array API -from xarray.core import dtypes from xarray.core.indexing import ExplicitlyIndexed from xarray.core.utils import Default, _default +from xarray.namedarray import dtypes +from xarray.namedarray.arithmetic import NamedArrayArithmetic +from xarray.namedarray.types import Dims, DimsInput, T_DuckArray from xarray.namedarray.utils import ( - T_DuckArray, is_duck_array, is_duck_dask_array, to_0d_object_array, ) -if typing.TYPE_CHECKING: - T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") - DimsInput = 
typing.Union[str, Iterable[Hashable]] - Dims = tuple[Hashable, ...] - - try: if sys.version_info >= (3, 11): from typing import Self @@ -68,7 +63,7 @@ def as_compatible_data( return typing.cast(T_DuckArray, np.asarray(data)) -class NamedArray: +class NamedArray(NamedArrayArithmetic): """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array. Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py new file mode 100644 index 00000000000..b5f471abec0 --- /dev/null +++ b/xarray/namedarray/dtypes.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +import functools + +import numpy as np + +from xarray.namedarray import utils + +# Use as a sentinel value to indicate a dtype appropriate NA value. +NA = utils.ReprObject("") + + +@functools.total_ordering +class AlwaysGreaterThan: + def __gt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +@functools.total_ordering +class AlwaysLessThan: + def __lt__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, type(self)) + + +# Equivalence to np.inf (-np.inf) for object-type +INF = AlwaysGreaterThan() +NINF = AlwaysLessThan() + + +# Pairs of types that, if both found, should be promoted to object dtype +# instead of following NumPy's own type-promotion rules. These type promotion +# rules match pandas instead. For reference, see the NumPy type hierarchy: +# https://numpy.org/doc/stable/reference/arrays.scalars.html +PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] 
= ( + (np.number, np.character), # numpy promotes to character + (np.bool_, np.character), # numpy promotes to character + (np.bytes_, np.str_), # numpy promotes to unicode +) + + +def maybe_promote(dtype): + """Simpler equivalent of pandas.core.common._maybe_promote + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + dtype : Promoted dtype that can hold missing values. + fill_value : Valid missing value for the promoted dtype. + """ + # N.B. these casting rules should match pandas + if np.issubdtype(dtype, np.floating): + fill_value = np.nan + elif np.issubdtype(dtype, np.timedelta64): + # See https://github.com/numpy/numpy/issues/10685 + # np.timedelta64 is a subclass of np.integer + # Check np.timedelta64 before np.integer + fill_value = np.timedelta64("NaT") + elif np.issubdtype(dtype, np.integer): + dtype = np.float32 if dtype.itemsize <= 2 else np.float64 + fill_value = np.nan + elif np.issubdtype(dtype, np.complexfloating): + fill_value = np.nan + np.nan * 1j + elif np.issubdtype(dtype, np.datetime64): + fill_value = np.datetime64("NaT") + else: + dtype = object + fill_value = np.nan + + dtype = np.dtype(dtype) + fill_value = dtype.type(fill_value) + return dtype, fill_value + + +NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} + + +def get_fill_value(dtype): + """Return an appropriate fill value for this dtype. + + Parameters + ---------- + dtype : np.dtype + + Returns + ------- + fill_value : Missing value corresponding to this dtype. + """ + _, fill_value = maybe_promote(dtype) + return fill_value + + +def get_pos_infinity(dtype, max_for_int=False): + """Return an appropriate positive infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + max_for_int : bool + Return np.iinfo(dtype).max instead of np.inf + + Returns + ------- + fill_value : positive infinity value corresponding to this dtype. 
+ """ + if issubclass(dtype.type, np.floating): + return np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype).max if max_for_int else np.inf + if issubclass(dtype.type, np.complexfloating): + return np.inf + 1j * np.inf + + return INF + + +def get_neg_infinity(dtype, min_for_int=False): + """Return an appropriate negative infinity for this dtype. + + Parameters + ---------- + dtype : np.dtype + min_for_int : bool + Return np.iinfo(dtype).min instead of -np.inf + + Returns + ------- + fill_value : negative infinity value corresponding to this dtype. + """ + if issubclass(dtype.type, np.floating): + return -np.inf + + if issubclass(dtype.type, np.integer): + return np.iinfo(dtype).min if min_for_int else -np.inf + if issubclass(dtype.type, np.complexfloating): + return -np.inf - 1j * np.inf + + return NINF + + +def is_datetime_like(dtype): + """Check if a dtype is a subclass of the numpy datetime types""" + return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) + + +def result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.dtype: + """Like np.result_type, but with type promotion rules matching pandas. + + Examples of changed behavior: + number + string -> object (not string) + bytes + unicode -> object (not unicode) + + Parameters + ---------- + *arrays_and_dtypes : list of arrays and dtypes + The dtype is extracted from both numpy and dask arrays. + + Returns + ------- + numpy.dtype for the result. 
+ """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py new file mode 100644 index 00000000000..8aff7894244 --- /dev/null +++ b/xarray/namedarray/pycompat.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from importlib import import_module +from types import ModuleType +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +from packaging.version import Version + +from xarray.namedarray.utils import is_duck_array, is_scalar, module_available + +integer_types = (int, np.integer) + +if TYPE_CHECKING: + ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"] + DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic + + +class DuckArrayModule: + """ + Solely for internal isinstance and version checks. 
+ + Motivated by having to only import pint when required (as pint currently imports xarray) + https://github.com/pydata/xarray/pull/5561#discussion_r664815718 + """ + + module: ModuleType | None + version: Version + type: DuckArrayTypes + available: bool + + def __init__(self, mod: ModType) -> None: + duck_array_module: ModuleType | None + duck_array_version: Version + duck_array_type: DuckArrayTypes + try: + duck_array_module = import_module(mod) + duck_array_version = Version(duck_array_module.__version__) + + if mod == "dask": + duck_array_type = (import_module("dask.array").Array,) + elif mod == "pint": + duck_array_type = (duck_array_module.Quantity,) + elif mod == "cupy": + duck_array_type = (duck_array_module.ndarray,) + elif mod == "sparse": + duck_array_type = (duck_array_module.SparseArray,) + elif mod == "cubed": + duck_array_type = (duck_array_module.Array,) + else: + raise NotImplementedError + + except (ImportError, AttributeError): # pragma: no cover + duck_array_module = None + duck_array_version = Version("0.0.0") + duck_array_type = () + + self.module = duck_array_module + self.version = duck_array_version + self.type = duck_array_type + self.available = duck_array_module is not None + + +_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {} + + +def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule: + if mod not in _cached_duck_array_modules: + duckmod = DuckArrayModule(mod) + _cached_duck_array_modules[mod] = duckmod + return duckmod + else: + return _cached_duck_array_modules[mod] + + +def array_type(mod: ModType) -> DuckArrayTypes: + """Quick wrapper to get the array class of the module.""" + return _get_cached_duck_array_module(mod).type + + +def mod_version(mod: ModType) -> Version: + """Quick wrapper to get the version of the module.""" + return _get_cached_duck_array_module(mod).version + + +def is_dask_collection(x): + if module_available("dask"): + from dask.base import is_dask_collection + + return 
is_dask_collection(x) + return False + + +def is_duck_dask_array(x): + return is_duck_array(x) and is_dask_collection(x) + + +def is_chunked_array(x) -> bool: + return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) + + +def is_0d_dask_array(x): + return is_duck_dask_array(x) and is_scalar(x) diff --git a/xarray/namedarray/types.py b/xarray/namedarray/types.py new file mode 100644 index 00000000000..cec409c34a4 --- /dev/null +++ b/xarray/namedarray/types.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import typing +from collections.abc import Hashable, Iterable + +import numpy as np + +if typing.TYPE_CHECKING: + from xarray.namedarray.core import NamedArray + + +T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") +DimsInput = typing.Union[str, Iterable[Hashable]] +Dims = tuple[Hashable, ...] + + +# temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more +T_DuckArray = typing.TypeVar("T_DuckArray", bound=typing.Any) + +ScalarOrArray = typing.Union[np.generic, np.ndarray] +NamedArrayCompatible = typing.Union[T_NamedArray, ScalarOrArray] diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 1495e111d85..66a86ada921 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -3,6 +3,7 @@ import importlib import sys import typing +from collections.abc import Hashable, Iterable import numpy as np @@ -12,9 +13,7 @@ else: from typing_extensions import TypeGuard -# temporary placeholder for indicating an array api compliant type. 
-# hopefully in the future we can narrow this down more -T_DuckArray = typing.TypeVar("T_DuckArray", bound=typing.Any) + from xarray.namedarray.types import T_DuckArray def module_available(module: str) -> bool: @@ -66,3 +65,82 @@ def to_0d_object_array(value: typing.Any) -> np.ndarray: result = np.empty((), dtype=object) result[()] = value return result + + +def _is_scalar(value, include_0d): + # TODO: figure out if the following is needed + # from xarray.core.variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES + NON_NUMPY_SUPPORTED_ARRAY_TYPES = () + + if include_0d: + include_0d = getattr(value, "ndim", None) == 0 + return ( + include_0d + or isinstance(value, (str, bytes)) + or not ( + isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) + or is_duck_array(value) + ) + ) + + +# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without +# requiring typing_extensions as a required dependency to _run_ the code (it is required +# to type-check). +try: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard +except ImportError: + if typing.TYPE_CHECKING: + raise + else: + + def is_scalar(value: typing.Any, include_0d: bool = True) -> bool: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + +else: + + def is_scalar(value: typing.Any, include_0d: bool = True) -> TypeGuard[Hashable]: + """Whether to treat a value as a scalar. 
+ + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + + +def is_valid_numpy_dtype(dtype: typing.Any) -> bool: + try: + np.dtype(dtype) + except (TypeError, ValueError): + return False + else: + return True + + +class ReprObject: + """Object that prints as the given value, for use with sentinel values.""" + + __slots__ = ("_value",) + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + return self._value + + def __eq__(self, other) -> bool: + return self._value == other._value if isinstance(other, ReprObject) else False + + def __hash__(self) -> int: + return hash((type(self), self._value)) + + def __dask_tokenize__(self): + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 0871a0c6fb9..6f4494f12d4 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -3,7 +3,7 @@ import xarray as xr from xarray.namedarray.core import NamedArray, as_compatible_data -from xarray.namedarray.utils import T_DuckArray +from xarray.namedarray.types import T_DuckArray @pytest.fixture diff --git a/xarray/tests/test_namedarray_utils.py b/xarray/tests/test_namedarray_utils.py new file mode 100644 index 00000000000..4db34d3f11d --- /dev/null +++ b/xarray/tests/test_namedarray_utils.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from collections.abc import Hashable + +from xarray.namedarray import utils + + +def test_repr_object(): + obj = utils.ReprObject("foo") + assert repr(obj) == "foo" + assert isinstance(obj, Hashable) + assert not isinstance(obj, str) + + +def test_repr_object_magic_methods(): + o1 = utils.ReprObject("foo") + o2 = utils.ReprObject("foo") + o3 = utils.ReprObject("bar") + o4 = "foo" + assert o1 == o2 + assert o1 != o3 + assert o1 != o4 + assert hash(o1) == hash(o2) + assert hash(o1) != hash(o3) + assert hash(o1) != hash(o4) diff 
--git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 36f62fad71f..8a86e1a40d8 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -135,26 +135,6 @@ def test_frozen(self): ) -def test_repr_object(): - obj = utils.ReprObject("foo") - assert repr(obj) == "foo" - assert isinstance(obj, Hashable) - assert not isinstance(obj, str) - - -def test_repr_object_magic_methods(): - o1 = utils.ReprObject("foo") - o2 = utils.ReprObject("foo") - o3 = utils.ReprObject("bar") - o4 = "foo" - assert o1 == o2 - assert o1 != o3 - assert o1 != o4 - assert hash(o1) == hash(o2) - assert hash(o1) != hash(o3) - assert hash(o1) != hash(o4) - - def test_is_remote_uri(): assert utils.is_remote_uri("http://example.com") assert utils.is_remote_uri("https://example.com")