Skip to content

Commit

Permalink
New properties Dataset.sizes and DataArray.sizes
Browse files Browse the repository at this point in the history
This allows for consistent access to dimension lengths on ``Dataset`` and
``DataArray``

xref pydata#921 (doesn't resolve it 100%, but should help significantly)
  • Loading branch information
shoyer committed Nov 3, 2016
1 parent a4f5ec2 commit 05079ae
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 73 deletions.
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ Enhancements
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
<https://github.com/pwolfram>`_.

- New properties :py:attr:`Dataset.sizes` and :py:attr:`DataArray.sizes` for
providing consistent access to dimension lengths on both ``Dataset`` and
``DataArray`` (:issue:`921`).
By `Stephan Hoyer <https://github.com/shoyer>`_.

Bug fixes
~~~~~~~~~
- ``groupby_bins`` now restores empty bins by default (:issue:`1019`).
Expand Down
43 changes: 30 additions & 13 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,36 @@ def pipe(self, func, *args, **kwargs):
else:
return func(self, *args, **kwargs)

def squeeze(self, dim=None):
    """Return a new object with squeezed data.

    Parameters
    ----------
    dim : None or str or tuple of str, optional
        Selects a subset of the length one dimensions. If a dimension is
        selected with length greater than one, an error is raised. If
        None, all length one dimensions are squeezed.

    Returns
    -------
    squeezed : same type as caller
        This object, but with all or a subset of the dimensions of
        length 1 removed.

    Raises
    ------
    ValueError
        If any selected dimension has length greater than one.

    See Also
    --------
    numpy.squeeze
    """
    if dim is None:
        # ``sizes`` maps dimension name -> length, so walk its
        # (name, length) pairs and keep every length-one dimension.
        dim = [d for d, s in self.sizes.items() if s == 1]
    else:
        if isinstance(dim, basestring):
            # A single name is treated as a one-element selection.
            dim = [dim]
        if any(self.sizes[k] > 1 for k in dim):
            raise ValueError('cannot select a dimension to squeeze out '
                             'which has length greater than one')
    # Selecting integer position 0 along each chosen dimension is how the
    # dimension is removed from the result.
    return self.isel(**dict((d, 0) for d in dim))

def groupby(self, group, squeeze=True):
"""Returns a GroupBy object for performing grouped operations.
Expand Down Expand Up @@ -615,19 +645,6 @@ def __exit__(self, exc_type, exc_value, traceback):
__or__ = __div__ = __eq__ = __ne__ = not_implemented


def squeeze(xarray_obj, dims, dim=None):
    """Drop length-one dimensions from an xarray object.

    Parameters
    ----------
    xarray_obj : object
        Object providing an ``isel(**indexers)`` method.
    dims : mapping
        Maps each dimension name to its length.
    dim : None or str or iterable of str, optional
        Dimension name(s) to squeeze. When None, every dimension whose
        length is 1 is squeezed.

    Returns
    -------
    The result of ``xarray_obj.isel`` with index 0 along each squeezed
    dimension.

    Raises
    ------
    ValueError
        If a requested dimension has length greater than one.
    """
    if dim is None:
        targets = [name for name, length in iteritems(dims) if length == 1]
    else:
        # A bare string names a single dimension.
        targets = [dim] if isinstance(dim, basestring) else dim
        for name in targets:
            if dims[name] > 1:
                raise ValueError('cannot select a dimension to squeeze out '
                                 'which has length greater than one')
    indexers = dict((name, 0) for name in targets)
    return xarray_obj.isel(**indexers)


def _maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote"""
# N.B. these casting rules should match pandas
Expand Down
76 changes: 45 additions & 31 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import contextlib
import functools
import warnings
Expand All @@ -7,6 +8,7 @@

from ..plot.plot import _PlotMethods

from . import formatting
from . import indexing
from . import groupby
from . import rolling
Expand Down Expand Up @@ -72,18 +74,17 @@ def _infer_coords_and_dims(shape, coords, dims):
if dim not in new_coords:
new_coords[dim] = default_index_coordinate(dim, size)

sizes = dict(zip(dims, shape))
for k, v in new_coords.items():
if any(d not in dims for d in v.dims):
raise ValueError('coordinate %s has dimensions %s, but these '
'are not a subset of the DataArray '
'dimensions %s' % (k, v.dims, dims))

for d, s in zip(v.dims, v.shape):
if s != sizes[d]:
if s != self.sizes[d]:
raise ValueError('conflicting sizes for dimension %r: '
'length %s on the data but length %s on '
'coordinate %r' % (d, sizes[d], s, k))
'coordinate %r' % (d, self.sizes[d], s, k))

assert_unique_multiindex_level_names(new_coords)

Expand All @@ -110,6 +111,31 @@ def __setitem__(self, key, value):
self.data_array[pos_indexers] = value


class DataArraySizes(collections.Mapping, formatting.ReprMixin):
    """Read-only mapping from an array's dimension names to their lengths.

    Keys come from ``array.dims`` and values from the matching positions
    of ``array.shape``; both attributes are read on every access, so no
    state is copied at construction time.
    """

    def __init__(self, array):
        self._array = array

    def __getitem__(self, key):
        wrapped = self._array
        try:
            position = wrapped.dims.index(key)
        except ValueError:
            # Unknown dimension name: surface it as a mapping-style error.
            raise KeyError(key)
        else:
            return wrapped.shape[position]

    def __contains__(self, key):
        names = self._array.dims
        return key in names

    def __iter__(self):
        names = self._array.dims
        return iter(names)

    def __len__(self):
        names = self._array.dims
        return len(names)

    def __unicode__(self):
        # Rendered as ``<ClassName (dim: length, ...)>``.
        pairs = [u'%s: %s' % (name, length) for name, length in self.items()]
        return u'<%s (%s)>' % (type(self).__name__, ', '.join(pairs))


class _ThisArray(object):
"""An instance of this object is used as the key corresponding to the
variable when converting arbitrary DataArray objects to datasets
Expand Down Expand Up @@ -411,14 +437,29 @@ def to_index(self):

@property
def dims(self):
    """Tuple of dimension names associated with this array.

    Note that the type of this property is inconsistent with `Dataset.dims`.
    See `Dataset.sizes` and `DataArray.sizes` for consistently named
    properties.
    """
    return self.variable.dims

@dims.setter
def dims(self, value):
    """Reject direct assignment of dimension names.

    Raises
    ------
    AttributeError
        Always; the message points to ``.rename()`` and ``.swap_dims()``.
    """
    message = ('you cannot assign dims on a DataArray. Use '
               '.rename() or .swap_dims() instead.')
    raise AttributeError(message)

@property
def sizes(self):
    """Lengths of this array's dimensions, keyed by dimension name.

    Returns
    -------
    DataArraySizes
        A mapping object wrapping this array.

    See also
    --------
    Dataset.sizes
    """
    size_mapping = DataArraySizes(self)
    return size_mapping

def _item_key_to_dict(self, key):
if utils.is_dict_like(key):
return key
Expand Down Expand Up @@ -911,33 +952,6 @@ def transpose(self, *dims):
variable = self.variable.transpose(*dims)
return self._replace(variable)

def squeeze(self, dim=None):
    """Return a new DataArray with length-one dimensions removed.

    Parameters
    ----------
    dim : None or str or tuple of str, optional
        Subset of the length-one dimensions to remove. Selecting a
        dimension whose length is greater than one raises an error. If
        None, every length-one dimension is squeezed.

    Returns
    -------
    squeezed : DataArray
        This array, but with all or a subset of the dimensions of
        length 1 removed.

    Notes
    -----
    Although this operation returns a view of this array's data, it is
    not lazy -- the data will be fully loaded.

    See Also
    --------
    numpy.squeeze
    """
    # Pair each dimension name with its length for the shared helper.
    lengths_by_dim = dict(zip(self.dims, self.shape))
    return squeeze(self, lengths_by_dim, dim)

def drop(self, labels, dim=None):
"""Drop coordinates or index labels from this DataArray.
Expand Down
45 changes: 16 additions & 29 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,25 @@ def attrs(self, value):
def dims(self):
    """Mapping from dimension names to lengths.

    Cannot be modified directly, but is updated when adding new variables.

    Note that the type of this object differs from `DataArray.dims`.
    See `Dataset.sizes` and `DataArray.sizes` for consistently named
    properties.
    """
    # Wrap the internal dict so keys iterate in sorted order, then freeze
    # the result before handing it out.
    ordered = SortedKeysDict(self._dims)
    return Frozen(ordered)

@property
def sizes(self):
    """Mapping from dimension names to lengths.

    This simply returns `Dataset.dims`; it exists so that the same
    attribute name works as on `DataArray.sizes`.

    Cannot be modified directly, but is updated when adding new variables.
    """
    dimension_lengths = self.dims
    return dimension_lengths

def load(self):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return this dataset.
Expand Down Expand Up @@ -1584,33 +1598,6 @@ def transpose(self, *dims):
def T(self):
    """Return the result of ``self.transpose()`` called with no arguments."""
    transposed = self.transpose()
    return transposed

def squeeze(self, dim=None):
    """Return a new dataset with length-one dimensions removed.

    Parameters
    ----------
    dim : None or str or tuple of str, optional
        Subset of the length-one dimensions to remove. Selecting a
        dimension whose length is greater than one raises an error. If
        None, every length-one dimension is squeezed.

    Returns
    -------
    squeezed : Dataset
        This dataset, but with all or a subset of the dimensions of
        length 1 removed.

    Notes
    -----
    Although this operation returns a view of each variable's data, it is
    not lazy -- all variable data will be fully loaded.

    See Also
    --------
    numpy.squeeze
    """
    dimension_lengths = self.dims
    return common.squeeze(self, dimension_lengths, dim)

def dropna(self, dim, how='any', thresh=None, subset=None):
"""Returns a new dataset with dropped labels for missing values along
the provided dimension.
Expand Down
16 changes: 16 additions & 0 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,22 @@ def test_dims(self):
with self.assertRaisesRegexp(AttributeError, 'you cannot assign'):
arr.dims = ('w', 'z')

def test_sizes(self):
    # A 3x4 array with dimensions named 'x' and 'y'.
    arr = DataArray(np.zeros((3, 4)), dims=['x', 'y'])
    expected = {'x': 3, 'y': 4}

    # The mapping as a whole compares equal to a plain dict.
    self.assertEqual(arr.sizes, expected)

    # Item access: unknown names raise KeyError, known names give lengths.
    with self.assertRaisesRegexp(KeyError, repr('foo')):
        arr.sizes['foo']
    for name, length in sorted(expected.items()):
        self.assertEqual(arr.sizes[name], length)

    # Membership.
    self.assertIn('x', arr.sizes)
    self.assertNotIn('foo', arr.sizes)

    # Iteration order, length and repr.
    self.assertEqual(tuple(arr.sizes), arr.dims)
    self.assertEqual(len(arr.sizes), 2)
    self.assertEqual(repr(arr.sizes), u'<DataArraySizes (x: 3, y: 4)>')

def test_encoding(self):
expected = {'foo': 'bar'}
self.dv.encoding['foo'] = 'bar'
Expand Down
1 change: 1 addition & 0 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def test_properties(self):
self.assertEqual(ds.dims,
{'dim1': 8, 'dim2': 9, 'dim3': 10, 'time': 20})
self.assertEqual(list(ds.dims), sorted(ds.dims))
self.assertEqual(ds.sizes, ds.dims)

# These exact types aren't public API, but this makes sure we don't
# change them inadvertently:
Expand Down

0 comments on commit 05079ae

Please sign in to comment.