From 4548d1015c38dc7c1c324b157da5939f232f4b46 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer
Date: Thu, 24 Jul 2014 23:17:10 -0700
Subject: [PATCH] Implemented Dataset.apply method

Fixes #140
---
 doc/api.rst          |  2 ++
 test/test_dataset.py | 35 +++++++++++++++++++++++++++++++++++
 xray/dataset.py      | 44 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 81 insertions(+)

diff --git a/doc/api.rst b/doc/api.rst
index b843af17900..8fe61c643ba 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -77,6 +77,8 @@ Computations
 .. autosummary::
    :toctree: generated/

+   Dataset.apply
+   Dataset.reduce
    Dataset.all
    Dataset.any
    Dataset.argmax
diff --git a/test/test_dataset.py b/test/test_dataset.py
index 6136e752027..992ffcb7be6 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -779,3 +779,38 @@ def test_reduce_keep_attrs(self):
         ds = data.mean(keep_attrs=True)
         self.assertEqual(len(ds.attrs), len(_attrs))
         self.assertTrue(ds.attrs, attrs)
+
+    def test_apply(self):
+        data = create_test_data()
+        data.attrs['foo'] = 'bar'
+
+        # applying a reduction to every noncoordinate matches Dataset.mean
+        self.assertDatasetIdentical(data.apply(np.mean), data.mean())
+        self.assertDatasetIdentical(data.apply(np.mean, keep_attrs=True),
+                                    data.mean(keep_attrs=True))
+
+        # the identity function returns the dataset without 'time'
+        self.assertDatasetIdentical(data.apply(lambda x: x, keep_attrs=True),
+                                    data.drop_vars('time'))
+
+        # naming var1, var2 and var3 explicitly matches the default
+        actual = data.apply(np.mean, to=['var1', 'var2', 'var3'])
+        self.assertDatasetIdentical(actual, data.mean())
+
+        # `to` also accepts a single name; other variables pass through
+        actual = data.apply(np.mean, to='var1')
+        modified = data.select_vars('var1').mean()
+        unmodified = data.select_vars('var2', 'var3')
+        expected = modified.merge(unmodified)
+        self.assertDatasetIdentical(actual, expected)
+
+        # unknown names in `to` raise ValueError
+        with self.assertRaisesRegexp(ValueError, 'does not contain'):
+            data.apply(np.mean, to='foobarbaz')
+
+        # extra keyword arguments are forwarded to `func`
+        def scale(x, multiple=1):
+            return multiple * x
+
+        actual = data.apply(scale, multiple=2)
+        self.assertDataArrayEqual(actual['var1'], 2 * data['var1'])
diff --git a/xray/dataset.py b/xray/dataset.py
index 58d5c413f93..e2a197e50a8 100644
--- a/xray/dataset.py
+++ b/xray/dataset.py
@@ -1091,6 +1091,50 @@ def reduce(self, func, dimension=None, keep_attrs=False, **kwargs):

         return Dataset(variables=variables, attributes=attrs)

+    def apply(self, func, to=None, keep_attrs=False, **kwargs):
+        """Apply a function over noncoordinates in this dataset.
+
+        Parameters
+        ----------
+        func : function
+            Function called as `f(x, **kwargs)` to transform each DataArray
+            `x` in this dataset into another DataArray.
+        to : str or sequence of str, optional
+            Names of the noncoordinates to apply `func` to; unlisted ones are
+            passed through unchanged. Defaults to all noncoordinates.
+        keep_attrs : bool, optional
+            If True, the dataset's attributes (`attrs`) are passed on to the
+            new dataset. If False, the new dataset gets no attributes.
+        **kwargs : dict
+            Additional keyword arguments passed on to `func`.
+
+        Returns
+        -------
+        applied : Dataset
+            Dataset of the resulting noncoordinates. Coordinates which are no
+            longer used as the dimension of a noncoordinate are dropped.
+
+        Raises
+        ------
+        ValueError
+            If `to` contains a name that is not a noncoordinate.
+        """
+        noncoords = self.noncoordinates
+        if to is None:
+            to = noncoords
+        to = set([to] if isinstance(to, basestring) else to)
+        missing = to - set(noncoords)
+        if missing:
+            names = sorted(missing, key=repr)  # stable message order
+            raise ValueError('Dataset does not contain the '
+                             'noncoordinates: %r' % names)
+
+        variables = OrderedDict()
+        for name, var in iteritems(noncoords):
+            variables[name] = func(var, **kwargs) if name in to else var
+        attrs = self.attrs if keep_attrs else {}
+        return Dataset(variables=variables, attributes=attrs)
+
     @classmethod
     def concat(cls, datasets, dimension='concat_dimension', indexers=None,
                mode='different', concat_over=None, compat='equals'):