Skip to content

Commit

Permalink
New properties Dataset.sizes and DataArray.sizes (#1076)
Browse files Browse the repository at this point in the history
This allows for consistent access to dimension lengths on ``Dataset`` and
``DataArray``

xref #921 (doesn't resolve it 100%, but should help significantly)
  • Loading branch information
shoyer authored Nov 4, 2016
1 parent a4f5ec2 commit 3f490a3
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 112 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Attributes
:toctree: generated/

Dataset.dims
Dataset.sizes
Dataset.data_vars
Dataset.coords
Dataset.attrs
Expand Down Expand Up @@ -187,6 +188,7 @@ Attributes
DataArray.data
DataArray.coords
DataArray.dims
DataArray.sizes
DataArray.name
DataArray.attrs
DataArray.encoding
Expand Down
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ Enhancements
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
<https://github.com/pwolfram>`_.

- New properties :py:attr:`Dataset.sizes` and :py:attr:`DataArray.sizes` for
providing consistent access to dimension lengths on both ``Dataset`` and
``DataArray`` (:issue:`921`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

Bug fixes
~~~~~~~~~
- ``groupby_bins`` now restores empty bins by default (:issue:`1019`).
Expand Down
70 changes: 54 additions & 16 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import pandas as pd

from .pycompat import basestring, iteritems, suppress, dask_array_type, bytes_type
from .pycompat import (basestring, iteritems, suppress, dask_array_type,
OrderedDict)
from . import formatting
from .utils import SortedKeysDict, not_implemented
from .utils import SortedKeysDict, not_implemented, Frozen


class ImplementsArrayReduce(object):
Expand Down Expand Up @@ -124,6 +125,8 @@ def wrapped_func(self, **kwargs):


class AbstractArray(ImplementsArrayReduce, formatting.ReprMixin):
"""Shared base class for DataArray and Variable."""

def __bool__(self):
return bool(self.values)

Expand Down Expand Up @@ -186,6 +189,18 @@ def _get_axis_num(self, dim):
raise ValueError("%r not found in array dimensions %r" %
(dim, self.dims))

@property
def sizes(self):
    """Immutable, ordered mapping of each dimension name to its length.

    Entries follow the order of ``self.dims``, each paired with the
    matching entry of ``self.shape``.

    See also
    --------
    Dataset.sizes
    """
    name_length_pairs = zip(self.dims, self.shape)
    return Frozen(OrderedDict(name_length_pairs))


class AttrAccessMixin(object):
"""Mixin class that allows getting keys with attribute access
Expand Down Expand Up @@ -231,7 +246,43 @@ def __dir__(self):
return sorted(set(dir(type(self)) + extra_attrs))


class BaseDataObject(AttrAccessMixin):
class SharedMethodsMixin(object):
    """Shared methods for Dataset, DataArray and Variable.

    The host class is expected to provide ``sizes`` (a mapping from
    dimension name to length) and ``isel`` (accepting ``name=index`` keyword
    arguments); both are used by the methods defined here.
    """

    def squeeze(self, dim=None):
        """Return a new object with squeezed data.

        Parameters
        ----------
        dim : None or str or tuple of str, optional
            Selects a subset of the length one dimensions. If a dimension is
            selected whose length is not exactly one, an error is raised. If
            None, all length one dimensions are squeezed.

        Returns
        -------
        squeezed : same type as caller
            This object, but with all or a subset of the dimensions of
            length 1 removed.

        Raises
        ------
        ValueError
            If any explicitly selected dimension does not have length one.

        See Also
        --------
        numpy.squeeze
        """
        if dim is None:
            dim = [d for d, s in self.sizes.items() if s == 1]
        else:
            if isinstance(dim, basestring):
                # A single name is treated as a one-element selection.
                dim = [dim]
            # Compare with ``!= 1`` rather than ``> 1`` so that a zero-length
            # dimension is rejected here as well, instead of being handed to
            # ``isel`` with index 0.
            if any(self.sizes[k] != 1 for k in dim):
                raise ValueError('cannot select a dimension to squeeze out '
                                 'which has length not equal to one')
        return self.isel(**{d: 0 for d in dim})


class BaseDataObject(SharedMethodsMixin, AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""

def _calc_assign_results(self, kwargs):
results = SortedKeysDict()
for k, v in kwargs.items():
Expand Down Expand Up @@ -615,19 +666,6 @@ def __exit__(self, exc_type, exc_value, traceback):
__or__ = __div__ = __eq__ = __ne__ = not_implemented


def squeeze(xarray_obj, dims, dim=None):
    """Squeeze the dims of an xarray object.

    Parameters
    ----------
    xarray_obj : object
        Object to squeeze; it must provide ``isel`` accepting ``name=index``
        keyword arguments.
    dims : mapping
        Mapping from the dimension names of ``xarray_obj`` to their lengths.
    dim : None or str or iterable of str, optional
        Dimensions to squeeze out. If None, every dimension of length one
        is squeezed.

    Returns
    -------
    The result of ``xarray_obj.isel`` with index 0 selected along every
    squeezed dimension.

    Raises
    ------
    ValueError
        If any explicitly selected dimension does not have length one.
    """
    if dim is None:
        dim = [d for d, s in dims.items() if s == 1]
    else:
        if isinstance(dim, basestring):
            # A single name is treated as a one-element selection.
            dim = [dim]
        # ``!= 1`` (not ``> 1``) so that zero-length dimensions are rejected
        # here too, instead of being handed to ``isel`` with index 0.
        if any(dims[k] != 1 for k in dim):
            raise ValueError('cannot select a dimension to squeeze out '
                             'which has length not equal to one')
    return xarray_obj.isel(**{d: 0 for d in dim})


def _maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote"""
# N.B. these casting rules should match pandas
Expand Down
37 changes: 7 additions & 30 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import functools
import warnings

Expand All @@ -13,7 +12,7 @@
from . import ops
from . import utils
from .alignment import align
from .common import AbstractArray, BaseDataObject, squeeze
from .common import AbstractArray, BaseDataObject
from .coordinates import (DataArrayCoordinates, LevelCoordinates,
Indexes)
from .dataset import Dataset
Expand Down Expand Up @@ -411,7 +410,12 @@ def to_index(self):

@property
def dims(self):
"""Dimension names associated with this array."""
"""Tuple of dimension names associated with this array.
Note that the type of this property is inconsistent with `Dataset.dims`.
See `Dataset.sizes` and `DataArray.sizes` for consistently named
properties.
"""
return self.variable.dims

@dims.setter
Expand Down Expand Up @@ -911,33 +915,6 @@ def transpose(self, *dims):
variable = self.variable.transpose(*dims)
return self._replace(variable)

def squeeze(self, dim=None):
"""Return a new DataArray object with squeezed data.

Parameters
----------
dim : None or str or tuple of str, optional
Selects a subset of the length one dimensions. If a dimension is
selected with length greater than one, an error is raised. If
None, all length one dimensions are squeezed.

Returns
-------
squeezed : DataArray
This array, but with all or a subset of the dimensions of
length 1 removed.

Notes
-----
Although this operation returns a view of this array's data, it is
not lazy -- the data will be fully loaded.

See Also
--------
numpy.squeeze
"""
return squeeze(self, dict(zip(self.dims, self.shape)), dim)

def drop(self, labels, dim=None):
"""Drop coordinates or index labels from this DataArray.
Expand Down
49 changes: 20 additions & 29 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,29 @@ def attrs(self, value):
def dims(self):
"""Mapping from dimension names to lengths.
This dictionary cannot be modified directly, but is updated when adding
new variables.
Cannot be modified directly, but is updated when adding new variables.
Note that the type of this object differs from `DataArray.dims`.
See `Dataset.sizes` and `DataArray.sizes` for consistently named
properties.
"""
return Frozen(SortedKeysDict(self._dims))

@property
def sizes(self):
"""Mapping from dimension names to lengths.

Cannot be modified directly, but is updated when adding new variables.

This is an alias for `Dataset.dims` provided for the benefit of
consistency with `DataArray.sizes`.

See also
--------
DataArray.sizes
"""
return self.dims

def load(self):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return this dataset.
Expand Down Expand Up @@ -1584,33 +1602,6 @@ def transpose(self, *dims):
def T(self):
return self.transpose()

def squeeze(self, dim=None):
"""Returns a new dataset with squeezed data.

Parameters
----------
dim : None or str or tuple of str, optional
Selects a subset of the length one dimensions. If a dimension is
selected with length greater than one, an error is raised. If
None, all length one dimensions are squeezed.

Returns
-------
squeezed : Dataset
This dataset, but with all or a subset of the dimensions of
length 1 removed.

Notes
-----
Although this operation returns a view of each variable's data, it is
not lazy -- all variable data will be fully loaded.

See Also
--------
numpy.squeeze
"""
return common.squeeze(self, self.dims, dim)

def dropna(self, dim, how='any', thresh=None, subset=None):
"""Returns a new dataset with dropped labels for missing values along
the provided dimension.
Expand Down
5 changes: 1 addition & 4 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
raise ValueError("`group` must have a 'dims' attribute")
group_dim, = group.dims

try:
expected_size = obj.dims[group_dim]
except TypeError:
expected_size = obj.shape[obj.get_axis_num(group_dim)]
expected_size = obj.sizes[group_dim]
if group.size != expected_size:
raise ValueError('the group variable\'s length does not '
'match the length of this variable along its '
Expand Down
36 changes: 3 additions & 33 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import defaultdict
import functools
import itertools
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -192,7 +191,8 @@ def _as_array_or_item(data):
return data


class Variable(common.AbstractArray, utils.NdimSizeLenMixin):
class Variable(common.AbstractArray, common.SharedMethodsMixin,
utils.NdimSizeLenMixin):

"""A netcdf-like variable consisting of dimensions, data and attributes
which describe a single Array. A single Variable object is not fully
Expand Down Expand Up @@ -678,34 +678,6 @@ def transpose(self, *dims):
data = ops.transpose(self.data, axes)
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)

def squeeze(self, dim=None):
"""Return a new Variable object with squeezed data.

Parameters
----------
dim : None or str or tuple of str, optional
Selects a subset of the length one dimensions. If a dimension is
selected with length greater than one, an error is raised. If
None, all length one dimensions are squeezed.

Returns
-------
squeezed : Variable
This array, but with all or a subset of the dimensions of
length 1 removed.

Notes
-----
Although this operation returns a view of this variable's data, it is
not lazy -- the data will be fully loaded.

See Also
--------
numpy.squeeze
"""
# Build the name -> length mapping expected by the shared helper.
dims = dict(zip(self.dims, self.shape))
return common.squeeze(self, dims, dim)

def expand_dims(self, dims, shape=None):
"""Return a new variable with expanded dimensions.
Expand Down Expand Up @@ -814,8 +786,7 @@ def _unstack_once(self, dims, old_dim):
raise ValueError('cannot create a new dimension with the same '
'name as an existing dimension')

axis = self.get_axis_num(old_dim)
if np.prod(new_dim_sizes) != self.shape[axis]:
if np.prod(new_dim_sizes) != self.sizes[old_dim]:
raise ValueError('the product of the new dimension sizes must '
'equal the size of the old dimension')

Expand Down Expand Up @@ -914,7 +885,6 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False,
dims = [adim for n, adim in enumerate(self.dims)
if n not in removed_axes]


attrs = self._attrs if keep_attrs else None

return Variable(dims, data, attrs=attrs)
Expand Down
7 changes: 7 additions & 0 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ def test_dims(self):
with self.assertRaisesRegexp(AttributeError, 'you cannot assign'):
arr.dims = ('w', 'z')

def test_sizes(self):
    # ``sizes`` should map each dimension name to its length, iterate in
    # the same order as ``dims`` and reject item assignment.
    arr = DataArray(np.zeros((3, 4)), dims=['x', 'y'])
    expected_sizes = {'x': 3, 'y': 4}
    self.assertEqual(arr.sizes, expected_sizes)
    names_in_order = tuple(arr.sizes)
    self.assertEqual(names_in_order, arr.dims)
    with self.assertRaises(TypeError):
        arr.sizes['foo'] = 5

def test_encoding(self):
expected = {'foo': 'bar'}
self.dv.encoding['foo'] = 'bar'
Expand Down
1 change: 1 addition & 0 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def test_properties(self):
self.assertEqual(ds.dims,
{'dim1': 8, 'dim2': 9, 'dim3': 10, 'time': 20})
self.assertEqual(list(ds.dims), sorted(ds.dims))
self.assertEqual(ds.sizes, ds.dims)

# These exact types aren't public API, but this makes sure we don't
# change them inadvertently:
Expand Down
1 change: 1 addition & 0 deletions xarray/test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_properties(self):
self.assertEqual(v.dtype, float)
self.assertEqual(v.shape, (10,))
self.assertEqual(v.size, 10)
self.assertEqual(v.sizes, {'time': 10})
self.assertEqual(v.nbytes, 80)
self.assertEqual(v.ndim, 1)
self.assertEqual(len(v), 10)
Expand Down

0 comments on commit 3f490a3

Please sign in to comment.