From 583dd95b2d4e37ca334d8e9bd37dcc0745767275 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 01:00:54 +0100 Subject: [PATCH 01/13] Update _array_api.py --- xarray/namedarray/_array_api.py | 49 +++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index e205c4d4efe..fb74fb6cfeb 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -8,11 +8,13 @@ from xarray.namedarray._typing import ( _arrayapi, + _Dim, _DType, _ScalarType, _ShapeType, _SupportsImag, _SupportsReal, + Default, ) from xarray.namedarray.core import NamedArray @@ -144,3 +146,50 @@ def real( xp = _get_data_namespace(x) out = x._new(data=xp.real(x._data)) return out + + +# %% Manipulation functions +def expand_dims( + x: NamedArray[Any, _DType], + /, + *, + dim: _Dim | Default = _default, + axis: _Axis = 0, +) -> NamedArray[Any, _DType]: + """ + Expands the shape of an array by inserting a new dimension of size one at the + position specified by dims. + + Parameters + ---------- + x : + Array to expand. + dim : + Dimension name. New dimension will be stored in the 0 position. + axis : + Axis position (zero-based). If x has rank (i.e, number of dimensions) N, + a valid axis must reside on the closed-interval [-N-1, N]. If provided a + negative axis, the axis position at which to insert a singleton dimension + must be computed as N + axis + 1. Hence, if provided -1, the resolved axis + position must be N (i.e., a singleton dimension must be appended to the + input array x). If provided -N-1, the resolved axis position must be 0 + (i.e., a singleton dimension must be prepended to the input array x). + + Returns + ------- + out : + An expanded output array having the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> expand_dims(x, dims="new_dim") + + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + """ + xp = _get_data_namespace(x) + d = list(x.dims) + d.insert(axis, dim) + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) + return out From eb313fd9b161a0ff7dd967438a7c40392780836a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:01:57 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index fb74fb6cfeb..13919f1f4d2 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -7,6 +7,7 @@ import numpy as np from xarray.namedarray._typing import ( + Default, _arrayapi, _Dim, _DType, @@ -14,7 +15,6 @@ _ShapeType, _SupportsImag, _SupportsReal, - Default, ) from xarray.namedarray.core import NamedArray From 23fcb2d81ae94212bce4632c7b8e2c8c36829a69 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:21:10 +0100 Subject: [PATCH 03/13] fixes --- xarray/namedarray/_array_api.py | 4 ++++ xarray/namedarray/_typing.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 13919f1f4d2..9fb4ca46fb1 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -9,7 +9,11 @@ from xarray.namedarray._typing import ( Default, _arrayapi, + _Axis, + _AxisLike, + _default, _Dim, + _Dims, _DType, _ScalarType, _ShapeType, diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 0b972e19539..de60b8ca7cf 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,10 +1,12 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum from types import ModuleType from typing import ( Any, Callable, + Final, Protocol, SupportsIndex, TypeVar, @@ -15,6 +17,14 @@ import numpy as np + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) From 530aedf4eebcc1e120abde2032bb45a5b60b54f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 15:21:52 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 9fb4ca46fb1..d6aeebb0632 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -10,10 +10,8 @@ Default, _arrayapi, _Axis, - _AxisLike, _default, _Dim, - _Dims, _DType, _ScalarType, _ShapeType, From 7205a32da5cb292eed754126b021b90a42c28b6e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:22:59 +0100 Subject: [PATCH 05/13] more --- xarray/namedarray/core.py | 3 ++- xarray/namedarray/utils.py | 15 +-------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2fef1cad3db..126e0b8b5e3 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -25,6 +25,7 @@ _arrayapi, _arrayfunction_or_api, _chunkedarray, + _default, _dtype, _DType_co, _ScalarType_co, @@ -33,7 +34,7 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array +from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 03eb0134231..4bd20931189 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -2,12 +2,7 @@ import sys from collections.abc import Hashable -from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Final, -) +from typing import TYPE_CHECKING, Any import numpy as np @@ -31,14 +26,6 @@ DaskCollection: Any = NDArray # type: ignore -# Singleton type, as per https://github.com/python/typing/pull/240 -class Default(Enum): - token: Final = 0 - - -_default = Default.token - - def module_available(module: str) -> bool: """Checks whether a module is installed without importing it. From a204fe5edd89b9c392934b3910dd7a697b2436d4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:42:24 +0100 Subject: [PATCH 06/13] fixes --- xarray/namedarray/_array_api.py | 31 +++++++++++++++++++------------ xarray/namedarray/_typing.py | 4 ++++ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index d6aeebb0632..7f1fce08c9e 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -9,8 +9,8 @@ from xarray.namedarray._typing import ( Default, _arrayapi, - _Axis, _default, + _Axis, _Dim, _DType, _ScalarType, @@ -167,15 +167,9 @@ def expand_dims( x : Array to expand. dim : - Dimension name. New dimension will be stored in the 0 position. + Dimension name. New dimension will be stored in the axis position. axis : - Axis position (zero-based). If x has rank (i.e, number of dimensions) N, - a valid axis must reside on the closed-interval [-N-1, N]. If provided a - negative axis, the axis position at which to insert a singleton dimension - must be computed as N + axis + 1. Hence, if provided -1, the resolved axis - position must be N (i.e., a singleton dimension must be appended to the - input array x). If provided -N-1, the resolved axis position must be 0 - (i.e., a singleton dimension must be prepended to the input array x). + (Not recommended) Axis position (zero-based). Default is 0. Returns ------- @@ -185,13 +179,26 @@ def expand_dims( Examples -------- >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) - >>> expand_dims(x, dims="new_dim") - + >>> expand_dims(x) + + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + >>> expand_dims(x, dim="z") + Array([[[1., 2.], [3., 4.]]], dtype=float64) """ xp = _get_data_namespace(x) - d = list(x.dims) + dims = x.dims + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) d.insert(axis, dim) out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) return out + + +if __name__ == "__main__": + import doctest + + doctest.testmod() diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index de60b8ca7cf..670a2076eb1 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -59,6 +59,10 @@ def dtype(self) -> _DType_co: _ShapeType = TypeVar("_ShapeType", bound=Any) _ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) +_Axis = int +_Axes = tuple[_Axis, ...] +_AxisLike = Union[_Axis, _Axes] + _Chunks = tuple[_Shape, ...] _Dim = Hashable From 11fee3995fa4d1666ec65a138388d279a39635b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 15:45:03 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 7f1fce08c9e..738b21664d8 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -9,8 +9,8 @@ from xarray.namedarray._typing import ( Default, _arrayapi, - _default, _Axis, + _default, _Dim, _DType, _ScalarType, From ae921480b2a1b72872d37e9c4c3307d38bc678b2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:51:03 +0100 Subject: [PATCH 08/13] Update test_namedarray.py --- xarray/tests/test_namedarray.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 448e8cf819a..0603fd49ccc 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -9,9 +9,13 @@ import pytest from xarray.core.indexing import ExplicitlyIndexed -from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co +from xarray.namedarray._typing import ( + _default, + _arrayfunction_or_api, + _DType_co, + _ShapeType_co, +) from xarray.namedarray.core import NamedArray, from_array -from xarray.namedarray.utils import _default if TYPE_CHECKING: from types import ModuleType @@ -24,8 +28,8 @@ _DType, _Shape, duckarray, + Default, ) - from xarray.namedarray.utils import Default class CustomArrayBase(Generic[_ShapeType_co, _DType_co]): From 5acf75aa65c17676545beaaf9d09c25de6bd72eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 15:52:00 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/test_namedarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 0603fd49ccc..82110e072b2 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -10,8 +10,8 @@ from xarray.core.indexing import ExplicitlyIndexed from xarray.namedarray._typing import ( - _default, _arrayfunction_or_api, + _default, _DType_co, _ShapeType_co, ) @@ -23,12 +23,12 @@ from numpy.typing import ArrayLike, DTypeLike, NDArray from xarray.namedarray._typing import ( + Default, _AttrsLike, _DimsLike, _DType, _Shape, duckarray, - Default, ) From 6e28051d0b9d7ba3411754d6678f2857c3b97159 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:55:24 +0100 Subject: [PATCH 10/13] Update _array_api.py --- xarray/namedarray/_array_api.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 738b21664d8..b5c320e0b96 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -196,9 +196,3 @@ def expand_dims( d.insert(axis, dim) out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) return out - - -if __name__ == "__main__": - import doctest - - doctest.testmod() From 66f82b171437429d8956104bcf1cf36a0c84a0a3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:07:33 +0100 Subject: [PATCH 11/13] Update variable.py --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index db109a40454..d1b387eb8b4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2599,7 +2599,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable: """ Use sparse-array as backend. """ - from xarray.namedarray.utils import _default as _default_named + from xarray.namedarray._typing import _default as _default_named if sparse_format is _default: sparse_format = _default_named From 3deb66c17a1bb9b641049767cb8395074a6b6591 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 17:12:52 +0100 Subject: [PATCH 12/13] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 126e0b8b5e3..00a18dfb353 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -53,8 +53,8 @@ _Shape, _ShapeType, duckarray, + Default, ) - from xarray.namedarray.utils import Default try: from dask.typing import ( From 84e0f143753422df09ade814162c9971e7ee06d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:15:55 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index b42a7688d3b..9c208bdefcc 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -41,6 +41,7 @@ from xarray.core.types import Dims from xarray.namedarray._typing import ( + Default, _AttrsLike, _Chunks, _Dim, @@ -52,7 +53,6 @@ _Shape, _ShapeType, duckarray, - Default, ) try: