Skip to content

Commit

Permalink
Merge pull request pydata#189 from shoyer/Dataset.apply
Browse files Browse the repository at this point in the history
Implementation of Dataset.apply method
  • Loading branch information
shoyer committed Jul 31, 2014
2 parents 2565be1 + 4548d10 commit 9597d9e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ Computations
.. autosummary::
:toctree: generated/

Dataset.apply
Dataset.reduce
Dataset.all
Dataset.any
Dataset.argmax
Expand Down
29 changes: 29 additions & 0 deletions test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,32 @@ def test_reduce_keep_attrs(self):
ds = data.mean(keep_attrs=True)
self.assertEqual(len(ds.attrs), len(_attrs))
self.assertTrue(ds.attrs, attrs)

def test_apply(self):
data = create_test_data()
data.attrs['foo'] = 'bar'

self.assertDatasetIdentical(data.apply(np.mean), data.mean())
self.assertDatasetIdentical(data.apply(np.mean, keep_attrs=True),
data.mean(keep_attrs=True))

self.assertDatasetIdentical(data.apply(lambda x: x, keep_attrs=True),
data.drop_vars('time'))

actual = data.apply(np.mean, to=['var1', 'var2', 'var3'])
self.assertDatasetIdentical(actual, data.mean())

actual = data.apply(np.mean, to='var1')
modified = data.select_vars('var1').mean()
unmodified = data.select_vars('var2', 'var3')
expected = modified.merge(unmodified)
self.assertDatasetIdentical(actual, expected)

with self.assertRaisesRegexp(ValueError, 'does not contain'):
data.apply(np.mean, to='foobarbaz')

def scale(x, multiple=1):
return multiple * x

actual = data.apply(scale, multiple=2)
self.assertDataArrayEqual(actual['var1'], 2 * data['var1'])
44 changes: 44 additions & 0 deletions xray/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,50 @@ def reduce(self, func, dimension=None, keep_attrs=False, **kwargs):

return Dataset(variables=variables, attributes=attrs)

def apply(self, func, to=None, keep_attrs=False, **kwargs):
"""Apply a function over noncoordinates in this dataset.
Parameters
----------
func : function
Function which can be called in the form `f(x, **kwargs)` to
transform each DataArray `x` in this dataset into another
DataArray.
to : str or sequence of str, optional
Explicit list of noncoordinates in this dataset to which to apply
`func`. Unlisted noncoordinates are passed through unchanged. By
default, `func` is applied to all noncoordinates in this dataset.
keep_attrs : bool, optional
If True, the datasets's attributes (`attrs`) will be copied from
the original object to the new one. If False, the new object will
be returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to `func`.
Returns
-------
applied : Dataset
Resulting dataset from applying over each noncoordinate.
Coordinates which are no longer used as the dimension of a
noncoordinate are dropped.
"""
if to is not None:
to = set([to] if isinstance(to, basestring) else to)
bad_to = to - set(self.noncoordinates)
if bad_to:
raise ValueError('Dataset does not contain the '
'noncoordinates: %r' % list(bad_to))
else:
to = set(self.noncoordinates)

variables = OrderedDict()
for name, var in iteritems(self.noncoordinates):
variables[name] = func(var, **kwargs) if name in to else var

attrs = self.attrs if keep_attrs else {}

return Dataset(variables, attrs)

@classmethod
def concat(cls, datasets, dimension='concat_dimension', indexers=None,
mode='different', concat_over=None, compat='equals'):
Expand Down

0 comments on commit 9597d9e

Please sign in to comment.