Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[namedarray] split .set_dims() into .expand_dims() and broadcast_to() #8380

Merged
merged 71 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
1224b54
add `.set_dims()`, `.transpose()` and `.T` to namedarray
andersy005 Oct 26, 2023
409c5de
more typying fixes
andersy005 Oct 26, 2023
43e10d8
more typing fixes
andersy005 Oct 27, 2023
d6c7758
override set_dims for IndexVariable
andersy005 Oct 27, 2023
a920e11
fix dims
andersy005 Oct 27, 2023
e765d7d
split `.set_dims()` into `.expand_dims()` and `broadcast_to()`
andersy005 Oct 27, 2023
00504f4
more typing fixes
andersy005 Oct 27, 2023
5a06dec
update whats-new
andersy005 Oct 27, 2023
ec17489
update tests
andersy005 Oct 27, 2023
a245021
doc fixes
andersy005 Oct 27, 2023
447f226
update whats-new
andersy005 Oct 27, 2023
9ad56a9
keep `.set_dims()` on `Variable()`
andersy005 Oct 27, 2023
a5918c3
update docs
andersy005 Oct 27, 2023
8091265
revert to set_dims
andersy005 Oct 27, 2023
2b93273
revert to .set_dims on Variable
andersy005 Oct 27, 2023
7ea1fb3
Update xarray/namedarray/core.py
andersy005 Oct 27, 2023
456d57c
restore .transpose on variable
andersy005 Oct 27, 2023
3f458c8
revert to set_dims in Variable
andersy005 Oct 28, 2023
c641ee6
Merge branch 'main' into add-set-dims
andersy005 Oct 28, 2023
aae2861
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 28, 2023
d6240de
fix docstring
andersy005 Oct 28, 2023
1364345
update test_namedarray
andersy005 Oct 28, 2023
b793f74
update tests
andersy005 Oct 28, 2023
b15971f
Merge branch 'main' into add-set-dims
andersy005 Nov 9, 2023
cc67b63
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2023
79119ce
Merge branch 'main' into add-set-dims
andersy005 Nov 20, 2023
f156744
Merge branch 'main' into add-set-dims
andersy005 Nov 28, 2023
ab6262b
Apply suggestions from code review
andersy005 Nov 28, 2023
6fbf90d
Merge branch 'main' into add-set-dims
andersy005 Nov 28, 2023
069c353
fix formatting issue
andersy005 Nov 28, 2023
5e1af7a
fix tests
andersy005 Nov 28, 2023
2499c93
update expand_dims
andersy005 Nov 30, 2023
3e4d8fa
update tests
andersy005 Nov 30, 2023
64b674a
Merge branch 'main' into add-set-dims
andersy005 Nov 30, 2023
ff01bce
update tests
andersy005 Nov 30, 2023
2c4b2b6
remove unnecessary guard conditions
andersy005 Nov 30, 2023
c0aefaa
Update type hints in NamedArray class and test cases
andersy005 Nov 30, 2023
e2cce05
Refactor NamedArray T property to handle non-2D arrays
andersy005 Nov 30, 2023
3633a2e
Reverse the order of dimensions in x.T
andersy005 Nov 30, 2023
326dad4
Refactor broadcasting and dimension expansion in NamedArray
andersy005 Dec 1, 2023
1c562a9
update docstring
andersy005 Dec 1, 2023
a7fd5c7
add todo item
andersy005 Dec 1, 2023
ff1a6fe
Merge branch 'main' into add-set-dims
andersy005 Dec 1, 2023
82e89c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2023
47121ef
Merge branch 'main' into add-set-dims
andersy005 Dec 1, 2023
31acb3b
use comprehension
andersy005 Dec 1, 2023
6fdbad6
use dim
andersy005 Dec 1, 2023
f1feb9f
Merge branch 'main' into add-set-dims
andersy005 Dec 1, 2023
401b6d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 1, 2023
c769639
fix imports
andersy005 Dec 1, 2023
dd64dbf
Merge branch 'main' into add-set-dims
andersy005 Dec 1, 2023
06cba49
Merge branch 'main' into add-set-dims
andersy005 Dec 5, 2023
959d97c
formatting only
andersy005 Dec 5, 2023
7be6a2d
Apply suggestions from code review
andersy005 Dec 5, 2023
3dfbce4
Merge branch 'main' into add-set-dims
andersy005 Dec 7, 2023
ea24613
[skip-rtd] fix indentation
andersy005 Dec 7, 2023
3fba906
Merge branch 'main' into add-set-dims
andersy005 Jan 23, 2024
f2a5989
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2024
8f494d3
Merge branch 'main' into add-set-dims
andersy005 Jan 24, 2024
5f7e5c5
refactor expand_dims to simplify API
andersy005 Jan 25, 2024
ce35128
Merge branch 'main' into add-set-dims
andersy005 Jan 25, 2024
0fb4443
fix type hint for `dim` parameter in test.
andersy005 Jan 25, 2024
a831c14
fix typing issues
andersy005 Jan 25, 2024
d880621
Merge branch 'main' into add-set-dims
andersy005 Jan 25, 2024
1524519
fix UnboundLocalError: local variable 'flattened_dims' referenced bef…
andersy005 Jan 26, 2024
e6af928
fix type hint
andersy005 Jan 26, 2024
7c3de30
Merge branch 'main' into add-set-dims
andersy005 Jan 26, 2024
a6ad9b8
ignore typing
andersy005 Jan 26, 2024
fd1c0d3
update whats-new
andersy005 Jan 26, 2024
6ec1170
Merge branch 'main' into add-set-dims
andersy005 Jan 28, 2024
e84abf9
adjust the `broadcast_to` method to prohibit adding new dimensions, a…
andersy005 Jan 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ def _set_dims(var):
# ignore dim not in var.dims
var_dims_map[dim] = var.shape[var.dims.index(dim)]

return var.set_dims(var_dims_map)
return var.expand_dims(var_dims_map)

def _broadcast_array(array: T_DataArray) -> T_DataArray:
data = _set_dims(array.variable)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def ensure_common_dims(vars, concat_dim_lengths):
for var, dim_len in zip(vars, concat_dim_lengths):
if var.dims != common_dims:
common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims)
var = var.set_dims(common_dims, common_shape)
var = var.expand_dims(common_dims, common_shape)
yield var

# get the indexes to concatenate together, create a PandasIndex
Expand All @@ -567,7 +567,7 @@ def get_indexes(name):
elif name == dim:
var = ds._variables[name]
if not var.dims:
data = var.set_dims(dim).values
data = var.expand_dims(dim).values
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# create concatenation index, needed for later reindexing
Expand Down
13 changes: 3 additions & 10 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@
import warnings
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from os import PathLike
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
NoReturn,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -64,6 +56,7 @@
as_compatible_data,
as_variable,
)
from xarray.namedarray.utils import infix_dims
from xarray.plot.accessor import DataArrayPlotAccessor
from xarray.plot.utils import _get_units_from_attrs
from xarray.util.deprecation_helpers import _deprecate_positional_args
Expand Down Expand Up @@ -2994,7 +2987,7 @@ def transpose(
Dataset.transpose
"""
if dims:
dims = tuple(utils.infix_dims(dims, self.dims, missing_dims))
dims = tuple(infix_dims(dims, self.dims, missing_dims))
variable = self.variable.transpose(*dims)
if transpose_coords:
coords: dict[Hashable, Variable] = {}
Expand Down
17 changes: 8 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@
decode_numpy_dict_values,
drop_dims_from_indexers,
either_dict_or_kwargs,
infix_dims,
is_dict_like,
is_scalar,
maybe_wrap_array,
)
Expand All @@ -122,6 +120,7 @@
broadcast_variables,
calculate_dimensions,
)
from xarray.namedarray.utils import infix_dims, is_dict_like
from xarray.plot.accessor import DatasetPlotAccessor
from xarray.util.deprecation_helpers import _deprecate_positional_args

Expand Down Expand Up @@ -4615,12 +4614,12 @@ def expand_dims(
all_dims = list(zip(v.dims, v.shape))
for d, c in zip_axis_dim:
all_dims.insert(d, c)
variables[k] = v.set_dims(dict(all_dims))
variables[k] = v.expand_dims(dict(all_dims))
else:
if k not in variables:
# If dims includes a label of a non-dimension coordinate,
# it will be promoted to a 1D coordinate with a single value.
index, index_vars = create_default_index_implicit(v.set_dims(k))
index, index_vars = create_default_index_implicit(v.expand_dims(k))
indexes[k] = index
variables.update(index_vars)

Expand Down Expand Up @@ -5143,8 +5142,8 @@ def _stack_once(
add_dims = [d for d in dims if d not in var.dims]
vdims = list(var.dims) + add_dims
shape = [self.dims[d] for d in vdims]
exp_var = var.set_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims})
exp_var = var.expand_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims}) # type: ignore
new_variables[name] = stacked_var
stacked_var_names.append(name)
else:
Expand Down Expand Up @@ -7091,7 +7090,7 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
columns = [k for k in self.variables if k not in self.dims]
data = [
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
self._variables[k].expand_dims(ordered_dims).values.reshape(-1)
for k in columns
]
index = self.coords.to_index([*ordered_dims])
Expand Down Expand Up @@ -7337,8 +7336,8 @@ def to_dask_dataframe(
var = var.chunk()

# Broadcast then flatten the array:
var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks)
dask_array = var_new_dims._data.reshape(-1)
var_new_dims = var.expand_dims(ordered_dims).chunk(ds_chunks)
dask_array = var_new_dims._data.reshape(-1) # type: ignore

series = dd.from_dask_array(dask_array, columns=name, meta=df_meta)
series_list.append(series)
Expand Down
11 changes: 2 additions & 9 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,7 @@
from abc import ABC, abstractmethod
from collections.abc import Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1052,7 +1045,7 @@ def _flox_reduce(
# broadcast and restore non-numeric data variables (backcompat)
for name, var in non_numeric.items():
if all(d not in var.dims for d in parsed_dim):
result[name] = var.variable.set_dims(
result[name] = var.variable.expand_dims(
(grouper.name,) + var.dims,
(result.sizes[grouper.name],) + var.shape,
)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def unique_variable(

if compat == "broadcast_equals":
dim_lengths = broadcast_dimension_size(variables)
out = out.set_dims(dim_lengths)
out = out.expand_dims(dim_lengths)

if compat == "no_conflicts":
combine_method = "fillna"
Expand Down
30 changes: 0 additions & 30 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,36 +789,6 @@ def __len__(self) -> int:
return len(self._data) - num_hidden


def infix_dims(
dims_supplied: Collection,
dims_all: Collection,
missing_dims: ErrorOptionsWithWarn = "raise",
) -> Iterator:
"""
Resolves a supplied list containing an ellipsis representing other items, to
a generator with the 'realized' list of all items
"""
if ... in dims_supplied:
if len(set(dims_all)) != len(dims_all):
raise ValueError("Cannot use ellipsis with repeated dims")
if list(dims_supplied).count(...) > 1:
raise ValueError("More than one ellipsis supplied")
other_dims = [d for d in dims_all if d not in dims_supplied]
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
for d in existing_dims:
if d is ...:
yield from other_dims
else:
yield d
else:
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
if set(existing_dims) ^ set(dims_all):
raise ValueError(
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
)
yield from existing_dims


def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
"""Get an new dimension name based on new_dim, that is not used in dims.
If the same name exists, we add an underscore(s) in the head.
Expand Down
116 changes: 11 additions & 105 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
drop_dims_from_indexers,
either_dict_or_kwargs,
ensure_us_time_resolution,
infix_dims,
is_duck_array,
maybe_coerce_to_str,
)
Expand Down Expand Up @@ -878,7 +877,7 @@ def __setitem__(self, key, value):
else:
value = Variable(dims[-value.ndim :], value)
# broadcast to become assignable
value = value.set_dims(dims).data
value = value.expand_dims(dims).data

if new_order:
value = duck_array_ops.asarray(value)
Expand Down Expand Up @@ -1382,107 +1381,6 @@ def roll(self, shifts=None, **shifts_kwargs):
result = result._roll_one_dim(dim, count)
return result

def transpose(
self,
*dims: Hashable | ellipsis,
missing_dims: ErrorOptionsWithWarn = "raise",
) -> Self:
"""Return a new Variable object with transposed dimensions.

Parameters
----------
*dims : Hashable, optional
By default, reverse the dimensions. Otherwise, reorder the
dimensions to this order.
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
What to do if dimensions that should be selected from are not present in the
Variable:
- "raise": raise an exception
- "warn": raise a warning, and ignore the missing dimensions
- "ignore": ignore the missing dimensions

Returns
-------
transposed : Variable
The returned object has transposed data and dimensions with the
same attributes as the original.

Notes
-----
This operation returns a view of this variable's data. It is
lazy for dask-backed Variables but not for numpy-backed Variables.

See Also
--------
numpy.transpose
"""
if len(dims) == 0:
dims = self.dims[::-1]
else:
dims = tuple(infix_dims(dims, self.dims, missing_dims))

if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
return self.copy(deep=False)

axes = self.get_axis_num(dims)
data = as_indexable(self._data).transpose(axes)
return self._replace(dims=dims, data=data)

@property
def T(self) -> Self:
return self.transpose()

def set_dims(self, dims, shape=None):
"""Return a new variable with given set of dimensions.
This method might be used to attach new dimension(s) to variable.

When possible, this operation does not copy this variable's data.

Parameters
----------
dims : str or sequence of str or dict
Dimensions to include on the new variable. If a dict, values are
used to provide the sizes of new dimensions; otherwise, new
dimensions are inserted with length 1.

Returns
-------
Variable
"""
if isinstance(dims, str):
dims = [dims]

if shape is None and utils.is_dict_like(dims):
shape = dims.values()

missing_dims = set(self.dims) - set(dims)
if missing_dims:
raise ValueError(
f"new dimensions {dims!r} must be a superset of "
f"existing dimensions {self.dims!r}"
)

self_dims = set(self.dims)
expanded_dims = tuple(d for d in dims if d not in self_dims) + self.dims

if self.dims == expanded_dims:
# don't use broadcast_to unless necessary so the result remains
# writeable if possible
expanded_data = self.data
elif shape is not None:
dims_map = dict(zip(dims, shape))
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
else:
expanded_data = self.data[(None,) * (len(expanded_dims) - self.ndim)]

expanded_var = Variable(
andersy005 marked this conversation as resolved.
Show resolved Hide resolved
expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
)
return expanded_var.transpose(*dims)

def _stack_once(self, dims: list[Hashable], new_dim: Hashable):
if not set(dims) <= set(self.dims):
raise ValueError(f"invalid existing dimensions: {dims}")
Expand Down Expand Up @@ -2851,6 +2749,11 @@ def _inplace_binary_op(self, other, f):
"Values of an IndexVariable are immutable and can not be modified inplace"
)

def _create_expanded_obj(self, expanded_data, expanded_dims) -> Variable: # type: ignore
return Variable(
expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
)


def _unified_dims(variables):
# validate dimensions
Expand Down Expand Up @@ -2880,7 +2783,9 @@ def _broadcast_compat_variables(*variables):
dimensions of size 1 instead of the size of the broadcast dimension.
"""
dims = tuple(_unified_dims(variables))
return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)
return tuple(
var.expand_dims(dims) if var.dims != dims else var for var in variables
)


def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:
Expand All @@ -2896,7 +2801,8 @@ def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:
dims_map = _unified_dims(variables)
dims_tuple = tuple(dims_map)
return tuple(
var.set_dims(dims_map) if var.dims != dims_tuple else var for var in variables
var.expand_dims(dims_map) if var.dims != dims_tuple else var
for var in variables
)


Expand Down
4 changes: 4 additions & 0 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Any,
Callable,
Literal,
Protocol,
SupportsIndex,
TypeVar,
Expand Down Expand Up @@ -263,3 +264,6 @@ def todense(self) -> NDArray[_ScalarType_co]:

# NamedArray can most likely use both __array_function__ and __array_namespace__:
_sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi)

ErrorOptions = Literal["raise", "ignore"]
ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]
Loading
Loading