Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Dataset.apply method #189

Merged
merged 1 commit into from
Jul 31, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ Computations
.. autosummary::
:toctree: generated/

Dataset.apply
Dataset.reduce
Dataset.all
Dataset.any
Dataset.argmax
Expand Down
29 changes: 29 additions & 0 deletions test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,3 +779,32 @@ def test_reduce_keep_attrs(self):
ds = data.mean(keep_attrs=True)
self.assertEqual(len(ds.attrs), len(_attrs))
self.assertTrue(ds.attrs, attrs)

def test_apply(self):
data = create_test_data()
data.attrs['foo'] = 'bar'

self.assertDatasetIdentical(data.apply(np.mean), data.mean())
self.assertDatasetIdentical(data.apply(np.mean, keep_attrs=True),
data.mean(keep_attrs=True))

self.assertDatasetIdentical(data.apply(lambda x: x, keep_attrs=True),
data.drop_vars('time'))

actual = data.apply(np.mean, to=['var1', 'var2', 'var3'])
self.assertDatasetIdentical(actual, data.mean())

actual = data.apply(np.mean, to='var1')
modified = data.select_vars('var1').mean()
unmodified = data.select_vars('var2', 'var3')
expected = modified.merge(unmodified)
self.assertDatasetIdentical(actual, expected)

with self.assertRaisesRegexp(ValueError, 'does not contain'):
data.apply(np.mean, to='foobarbaz')

def scale(x, multiple=1):
return multiple * x

actual = data.apply(scale, multiple=2)
self.assertDataArrayEqual(actual['var1'], 2 * data['var1'])
44 changes: 44 additions & 0 deletions xray/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,50 @@ def reduce(self, func, dimension=None, keep_attrs=False, **kwargs):

return Dataset(variables=variables, attributes=attrs)

def apply(self, func, to=None, keep_attrs=False, **kwargs):
"""Apply a function over noncoordinates in this dataset.

Parameters
----------
func : function
Function which can be called in the form `f(x, **kwargs)` to
transform each DataArray `x` in this dataset into another
DataArray.
to : str or sequence of str, optional
Explicit list of noncoordinates in this dataset to which to apply
`func`. Unlisted noncoordinates are passed through unchanged. By
default, `func` is applied to all noncoordinates in this dataset.
keep_attrs : bool, optional
If True, the datasets's attributes (`attrs`) will be copied from
the original object to the new one. If False, the new object will
be returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to `func`.

Returns
-------
applied : Dataset
Resulting dataset from applying over each noncoordinate.
Coordinates which are no longer used as the dimension of a
noncoordinate are dropped.
"""
if to is not None:
to = set([to] if isinstance(to, basestring) else to)
bad_to = to - set(self.noncoordinates)
if bad_to:
raise ValueError('Dataset does not contain the '
'noncoordinates: %r' % list(bad_to))
else:
to = set(self.noncoordinates)

variables = OrderedDict()
for name, var in iteritems(self.noncoordinates):
variables[name] = func(var, **kwargs) if name in to else var

attrs = self.attrs if keep_attrs else {}

return Dataset(variables, attrs)

@classmethod
def concat(cls, datasets, dimension='concat_dimension', indexers=None,
mode='different', concat_over=None, compat='equals'):
Expand Down