Deprecate ds.dims returning dict (#8500)
* raise FutureWarning

* change some internal instances of ds.dims -> ds.sizes

* improve clarity of which unexpected errors were raised

* whatsnew

* return a class which warns if treated like a Mapping

* fix failing tests

* avoid some warnings in the docs

* silence warning caused by #8491

* fix another warning

* typing of .get

* fix various uses of ds.dims in tests

* fix some warnings

* add test that FutureWarnings are correctly raised

* more fixes to avoid warnings

* update tests to avoid warnings

* yet more fixes to avoid warnings

* also warn in groupby.dims

* change groupby tests to match

* update whatsnew to include groupby deprecation

* filter warning when we actually test ds.dims

* remove error I used for debugging

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
TomNicholas and dcherian authored Dec 6, 2023
1 parent 3fc0ee5 commit 299abd6
Showing 20 changed files with 170 additions and 65 deletions.
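The heart of the change is a deprecation shim: `Dataset.dims` keeps returning a mapping for now, but emits a `FutureWarning` whenever callers use it in a way that only a mapping (and not a plain set of dimension names) supports. A minimal stand-alone sketch of the pattern in plain Python — the class name, warning text, and sizes here are illustrative stand-ins, not xarray's actual implementation (see `FrozenMappingWarningOnValuesAccess` in the `xarray/core/utils.py` hunk for the real one):

```python
import warnings
from collections.abc import Mapping


class WarnOnValuesAccessMapping(Mapping):
    """Read-only mapping that warns on dict-like use (lookup, .keys(),
    .values(), .items()) but not on set-like use (iteration, membership,
    len), since those would remain valid for a set of names."""

    def __init__(self, mapping):
        self._mapping = dict(mapping)

    def _warn(self):
        warnings.warn(
            "The return type will become a set of names in the future; "
            "use .sizes for a mapping from names to lengths.",
            FutureWarning,
            stacklevel=3,
        )

    def __getitem__(self, key):
        self._warn()  # dict-like access: deprecated
        return self._mapping[key]

    # Set-like operations stay silent.
    def __iter__(self):
        return iter(self._mapping)

    def __len__(self):
        return len(self._mapping)

    def __contains__(self, key):
        # Override the Mapping default, which would call __getitem__ and warn.
        return key in self._mapping

    def keys(self):
        self._warn()
        return self._mapping.keys()

    def items(self):
        self._warn()
        return self._mapping.items()

    def values(self):
        self._warn()
        return self._mapping.values()


dims = WarnOnValuesAccessMapping({"x": 3, "y": 4})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    names = sorted(dims)  # set-like use: no warning
    n_x = dims["x"]       # dict-like use: warns
```

Iterating, `len()`, and `in` keep working silently, so downstream code that only treats the result as a collection of names will not need to change when the return type flips.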
2 changes: 1 addition & 1 deletion doc/gallery/plot_cartopy_facetgrid.py
@@ -30,7 +30,7 @@
transform=ccrs.PlateCarree(), # the data's projection
col="time",
col_wrap=1, # multiplot settings
-aspect=ds.dims["lon"] / ds.dims["lat"],  # for a sensible figsize
+aspect=ds.sizes["lon"] / ds.sizes["lat"],  # for a sensible figsize
subplot_kws={"projection": map_proj}, # the plot's projection
)

4 changes: 2 additions & 2 deletions doc/user-guide/interpolation.rst
@@ -292,8 +292,8 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data.
axes[0].set_title("Raw data")
# Interpolated data
-new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims["lon"] * 4)
-new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims["lat"] * 4)
+new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.sizes["lon"] * 4)
+new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.sizes["lat"] * 4)
dsi = ds.interp(lat=new_lat, lon=new_lon)
dsi.air.plot(ax=axes[1])
@savefig interpolation_sample3.png width=8in
9 changes: 4 additions & 5 deletions doc/user-guide/terminology.rst
@@ -47,9 +47,9 @@ complete examples, please consult the relevant documentation.*
all but one of these degrees of freedom is fixed. We can think of each
dimension axis as having a name, for example the "x dimension". In
xarray, a ``DataArray`` object's *dimensions* are its named dimension
-axes, and the name of the ``i``-th dimension is ``arr.dims[i]``. If an
-array is created without dimension names, the default dimension names are
-``dim_0``, ``dim_1``, and so forth.
+axes ``da.dims``, and the name of the ``i``-th dimension is ``da.dims[i]``.
+If an array is created without specifying dimension names, the default dimension
+names will be ``dim_0``, ``dim_1``, and so forth.

Coordinate
An array that labels a dimension or set of dimensions of another
@@ -61,8 +61,7 @@ complete examples, please consult the relevant documentation.*
``arr.coords[x]``. A ``DataArray`` can have more coordinates than
dimensions because a single dimension can be labeled by multiple
coordinate arrays. However, only one coordinate array can be assigned
-as a particular dimension's dimension coordinate array. As a
-consequence, ``len(arr.dims) <= len(arr.coords)`` in general.
+as a particular dimension's dimension coordinate array.

Dimension coordinate
A one-dimensional coordinate array assigned to ``arr`` with both a name
13 changes: 10 additions & 3 deletions doc/whats-new.rst
@@ -66,10 +66,17 @@ Deprecations
currently ``PendingDeprecationWarning``, which are silenced by default. We'll
convert these to ``DeprecationWarning`` in a future release.
By `Maximilian Roos <https://github.com/max-sixty>`_.
-- :py:meth:`Dataset.drop` &
-  :py:meth:`DataArray.drop` are now deprecated, since pending deprecation for
+- Raise a ``FutureWarning`` that the type of :py:meth:`Dataset.dims` will be changed
+  from a mapping of dimension names to lengths to a set of dimension names.
+  This is to increase consistency with :py:meth:`DataArray.dims`.
+  To access a mapping of dimension names to lengths, please use :py:meth:`Dataset.sizes`.
+  The same change also applies to ``DatasetGroupBy.dims``.
+  (:issue:`8496`, :pull:`8500`)
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
+- :py:meth:`Dataset.drop` & :py:meth:`DataArray.drop` are now deprecated, since pending deprecation for
   several years. :py:meth:`DataArray.drop_sel` & :py:meth:`DataArray.drop_var`
-  replace them for labels & variables respectively.
+  replace them for labels & variables respectively. (:pull:`8497`)
   By `Maximilian Roos <https://github.com/max-sixty>`_.

Bug fixes
~~~~~~~~~
2 changes: 1 addition & 1 deletion xarray/core/common.py
@@ -1167,7 +1167,7 @@ def _dataset_indexer(dim: Hashable) -> DataArray:
cond_wdim = cond.drop_vars(
var for var in cond if dim not in cond[var].dims
)
-keepany = cond_wdim.any(dim=(d for d in cond.dims.keys() if d != dim))
+keepany = cond_wdim.any(dim=(d for d in cond.dims if d != dim))
return keepany.to_dataarray().any("variable")

_get_indexer = (
4 changes: 2 additions & 2 deletions xarray/core/concat.py
@@ -315,7 +315,7 @@ def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, c
if dim in ds:
ds = ds.set_coords(dim)
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
-concat_dim_lengths.append(ds.dims.get(dim, 1))
+concat_dim_lengths.append(ds.sizes.get(dim, 1))

def process_subset_opt(opt, subset):
if isinstance(opt, str):
@@ -431,7 +431,7 @@ def _parse_datasets(
variables_order: dict[Hashable, Variable] = {} # variables in order of appearance

for ds in datasets:
-dims_sizes.update(ds.dims)
+dims_sizes.update(ds.sizes)
all_coord_names.update(ds.coords)
data_vars.update(ds.data_vars)
variables_order.update(ds.variables)
40 changes: 21 additions & 19 deletions xarray/core/dataset.py
@@ -105,6 +105,7 @@
from xarray.core.utils import (
Default,
Frozen,
FrozenMappingWarningOnValuesAccess,
HybridMappingProxy,
OrderedSet,
_default,
@@ -778,14 +779,15 @@ def dims(self) -> Frozen[Hashable, int]:
Note that the type of this object differs from `DataArray.dims`.
See `Dataset.sizes` and `DataArray.sizes` for consistently named
-properties.
+properties. This property will be changed to return a type more consistent with
+`DataArray.dims` in the future, i.e. a set of dimension names.

See Also
--------
Dataset.sizes
DataArray.dims
"""
-return Frozen(self._dims)
+return FrozenMappingWarningOnValuesAccess(self._dims)

@property
def sizes(self) -> Frozen[Hashable, int]:
@@ -800,7 +802,7 @@ def sizes(self) -> Frozen[Hashable, int]:
--------
DataArray.sizes
"""
-return self.dims
+return Frozen(self._dims)

@property
def dtypes(self) -> Frozen[Hashable, np.dtype]:
@@ -1411,7 +1413,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
variables[name] = self._variables[name]
except KeyError:
ref_name, var_name, var = _get_virtual_variable(
-self._variables, name, self.dims
+self._variables, name, self.sizes
)
variables[var_name] = var
if ref_name in self._coord_names or ref_name in self.dims:
@@ -1426,7 +1428,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
for v in variables.values():
needed_dims.update(v.dims)

-dims = {k: self.dims[k] for k in needed_dims}
+dims = {k: self.sizes[k] for k in needed_dims}

# preserves ordering of coordinates
for k in self._variables:
@@ -1448,7 +1450,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
try:
variable = self._variables[name]
except KeyError:
-_, name, variable = _get_virtual_variable(self._variables, name, self.dims)
+_, name, variable = _get_virtual_variable(self._variables, name, self.sizes)

needed_dims = set(variable.dims)

@@ -1475,7 +1477,7 @@ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)

# virtual coordinates
-yield HybridMappingProxy(keys=self.dims, mapping=self)
+yield HybridMappingProxy(keys=self.sizes, mapping=self)

def __contains__(self, key: object) -> bool:
"""The 'in' operator will return true or false depending on whether
@@ -2569,7 +2571,7 @@ def info(self, buf: IO | None = None) -> None:
lines = []
lines.append("xarray.Dataset {")
lines.append("dimensions:")
-for name, size in self.dims.items():
+for name, size in self.sizes.items():
lines.append(f"\t{name} = {size} ;")
lines.append("\nvariables:")
for name, da in self.variables.items():
@@ -2697,10 +2699,10 @@ def chunk(
else:
chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

-bad_dims = chunks_mapping.keys() - self.dims.keys()
+bad_dims = chunks_mapping.keys() - self.sizes.keys()
if bad_dims:
raise ValueError(
-f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}"
+f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)

chunkmanager = guess_chunkmanager(chunked_array_type)
@@ -3952,7 +3954,7 @@ def maybe_variable(obj, k):
try:
return obj._variables[k]
except KeyError:
-return as_variable((k, range(obj.dims[k])))
+return as_variable((k, range(obj.sizes[k])))

def _validate_interp_indexer(x, new_x):
# In the case of datetimes, the restrictions placed on indexers
@@ -4176,7 +4178,7 @@ def _rename_vars(
return variables, coord_names

def _rename_dims(self, name_dict: Mapping[Any, Hashable]) -> dict[Hashable, int]:
-return {name_dict.get(k, k): v for k, v in self.dims.items()}
+return {name_dict.get(k, k): v for k, v in self.sizes.items()}

def _rename_indexes(
self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable]
@@ -5168,7 +5170,7 @@ def _get_stack_index(
if dim in self._variables:
var = self._variables[dim]
else:
-_, _, var = _get_virtual_variable(self._variables, dim, self.dims)
+_, _, var = _get_virtual_variable(self._variables, dim, self.sizes)
# dummy index (only `stack_coords` will be used to construct the multi-index)
stack_index = PandasIndex([0], dim)
stack_coords = {dim: var}
@@ -5195,7 +5197,7 @@ def _stack_once(
if any(d in var.dims for d in dims):
add_dims = [d for d in dims if d not in var.dims]
vdims = list(var.dims) + add_dims
-shape = [self.dims[d] for d in vdims]
+shape = [self.sizes[d] for d in vdims]
exp_var = var.set_dims(vdims, shape)
stacked_var = exp_var.stack(**{new_dim: dims})
new_variables[name] = stacked_var
@@ -6351,15 +6353,15 @@ def dropna(
if subset is None:
subset = iter(self.data_vars)

-count = np.zeros(self.dims[dim], dtype=np.int64)
+count = np.zeros(self.sizes[dim], dtype=np.int64)
size = np.int_(0) # for type checking

for k in subset:
array = self._variables[k]
if dim in array.dims:
dims = [d for d in array.dims if d != dim]
count += np.asarray(array.count(dims))
-size += math.prod([self.dims[d] for d in dims])
+size += math.prod([self.sizes[d] for d in dims])

if thresh is not None:
mask = count >= thresh
@@ -7136,7 +7138,7 @@ def _normalize_dim_order(
f"Dataset: {list(self.dims)}"
)

-ordered_dims = {k: self.dims[k] for k in dim_order}
+ordered_dims = {k: self.sizes[k] for k in dim_order}

return ordered_dims

@@ -7396,7 +7398,7 @@ def to_dask_dataframe(
var = self.variables[name]
except KeyError:
# dimension without a matching coordinate
-size = self.dims[name]
+size = self.sizes[name]
data = da.arange(size, chunks=size, dtype=np.int64)
var = Variable((name,), data)

@@ -7469,7 +7471,7 @@ def to_dict(
d: dict = {
"coords": {},
"attrs": decode_numpy_dict_values(self.attrs),
-"dims": dict(self.dims),
+"dims": dict(self.sizes),
"data_vars": {},
}
for k in self.coords:
2 changes: 1 addition & 1 deletion xarray/core/formatting.py
@@ -739,7 +739,7 @@ def dataset_repr(ds):


def diff_dim_summary(a, b):
-if a.dims != b.dims:
+if a.sizes != b.sizes:
return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
else:
return ""
11 changes: 6 additions & 5 deletions xarray/core/formatting_html.py
@@ -37,17 +37,18 @@ def short_data_repr_html(array) -> str:
return f"<pre>{text}</pre>"


-def format_dims(dims, dims_with_index) -> str:
-if not dims:
+def format_dims(dim_sizes, dims_with_index) -> str:
+if not dim_sizes:
return ""

dim_css_map = {
-dim: " class='xr-has-index'" if dim in dims_with_index else "" for dim in dims
+dim: " class='xr-has-index'" if dim in dims_with_index else ""
+for dim in dim_sizes
}

dims_li = "".join(
f"<li><span{dim_css_map[dim]}>" f"{escape(str(dim))}</span>: {size}</li>"
-for dim, size in dims.items()
+for dim, size in dim_sizes.items()
)

return f"<ul class='xr-dim-list'>{dims_li}</ul>"
@@ -204,7 +205,7 @@ def _mapping_section(


def dim_section(obj) -> str:
-dim_list = format_dims(obj.dims, obj.xindexes.dims)
+dim_list = format_dims(obj.sizes, obj.xindexes.dims)

return collapsible_section(
"Dimensions", inline_details=dim_list, enabled=False, collapsed=True
3 changes: 2 additions & 1 deletion xarray/core/groupby.py
@@ -36,6 +36,7 @@
from xarray.core.pycompat import integer_types
from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
from xarray.core.utils import (
FrozenMappingWarningOnValuesAccess,
either_dict_or_kwargs,
hashable,
is_scalar,
@@ -1519,7 +1520,7 @@ def dims(self) -> Frozen[Hashable, int]:
if self._dims is None:
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims

-return self._dims
+return FrozenMappingWarningOnValuesAccess(self._dims)

def map(
self,
54 changes: 54 additions & 0 deletions xarray/core/utils.py
@@ -50,12 +50,15 @@
Collection,
Container,
Hashable,
ItemsView,
Iterable,
Iterator,
KeysView,
Mapping,
MutableMapping,
MutableSet,
Sequence,
ValuesView,
)
from enum import Enum
from typing import (
@@ -473,6 +476,57 @@ def FrozenDict(*args, **kwargs) -> Frozen:
return Frozen(dict(*args, **kwargs))


class FrozenMappingWarningOnValuesAccess(Frozen[K, V]):
"""
Class which behaves like a Mapping but warns if the values are accessed.

Temporary object to aid in deprecation cycle of `Dataset.dims` (see GH issue #8496).
`Dataset.dims` is being changed from returning a mapping of dimension names to lengths to just
returning a frozen set of dimension names (to increase consistency with `DataArray.dims`).

This class retains backwards compatibility but raises a warning only if the return value
of ds.dims is used like a dictionary (i.e. it doesn't raise a warning if used in a way that
would also be valid for a FrozenSet, e.g. iteration).
"""

__slots__ = ("mapping",)

def _warn(self) -> None:
warnings.warn(
"The return type of `Dataset.dims` will be changed to return a set of dimension names in future, "
"in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, "
"please use `Dataset.sizes`.",
FutureWarning,
)

def __getitem__(self, key: K) -> V:
self._warn()
return super().__getitem__(key)

@overload
def get(self, key: K, /) -> V | None:
...

@overload
def get(self, key: K, /, default: V | T) -> V | T:
...

def get(self, key: K, default: T | None = None) -> V | T | None:
self._warn()
return super().get(key, default)

def keys(self) -> KeysView[K]:
self._warn()
return super().keys()

def items(self) -> ItemsView[K, V]:
self._warn()
return super().items()

def values(self) -> ValuesView[V]:
self._warn()
return super().values()


class HybridMappingProxy(Mapping[K, V]):
"""Implements the Mapping interface. Uses the wrapped mapping for item lookup
and a separate wrapped keys collection for iteration.
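Several of the commit-message items above ("add test that FutureWarnings are correctly raised", "filter warning when we actually test ds.dims") amount to the two standard testing moves for a deprecation cycle: assert that the warning fires, and silence it where the old API is exercised deliberately. A stdlib-only sketch of both moves — `legacy_dims_mapping` is a hypothetical stand-in for the deprecated accessor, not xarray code (xarray's own tests use the equivalent pytest helpers):

```python
import warnings


def legacy_dims_mapping():
    # Hypothetical stand-in for the deprecated ds.dims mapping access.
    warnings.warn(
        "The return type of `Dataset.dims` will be changed to return a set "
        "of dimension names in future; please use `Dataset.sizes`.",
        FutureWarning,
    )
    return {"x": 3}


# 1. Assert the FutureWarning is correctly raised
#    (analogous to pytest.warns(FutureWarning)).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sizes = legacy_dims_mapping()

# 2. Deliberately exercise the deprecated path without warning noise:
#    filter the specific category, as a test suite does when it
#    intentionally tests the old behaviour.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    sizes = legacy_dims_mapping()  # warning is suppressed here
```

Filtering by category (rather than globally disabling warnings) keeps any *other* unexpected warnings visible, which is what makes it safe to use inside individual tests.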