Skip to content

Commit

Permalink
Add choose method (#3)
Browse files Browse the repository at this point in the history
* Add `choose_any` and `choose_all` methods

* Update docstring

* Incorporate @dcherian's feedback

* consolidate functions

* Default to mode='any'
  • Loading branch information
andersy005 authored Aug 16, 2021
1 parent 973126a commit b56be0d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 1 deletion.
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ dependencies:
- pytest-xdist
- xarray>=0.18
- zarr
- toolz
- pip:
- -r ../requirements.txt
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pydantic
xarray
toolz
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ extend-ignore = E203,E501,E402,W605

[isort]
known_first_party=xcollection
known_third_party=pkg_resources,pydantic,pytest,setuptools,xarray
known_third_party=pkg_resources,pydantic,pytest,setuptools,toolz,xarray
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
Expand Down
28 changes: 28 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import xcollection

ds = xr.tutorial.open_dataset('rasm')
dsa = xr.tutorial.open_dataset('air_temperature')


@pytest.mark.parametrize('datasets', [None, {'a': ds, 'b': ds}, {'test': ds.Tair}])
Expand Down Expand Up @@ -69,3 +70,30 @@ def test_getitem():
def test_iter():
c = xcollection.Collection()
assert isinstance(iter(c), typing.Iterator)


@pytest.mark.parametrize('data_vars', ['Tair', ['Tair']])
def test_choose_all(data_vars):
c = xcollection.Collection({'foo': ds, 'bar': ds})
d = c.choose(data_vars, mode='all')
assert c == d
assert set(d.keys()) == {'foo', 'bar'}


def test_choose_all_error():
c = xcollection.Collection({'foo': ds, 'bar': dsa})
with pytest.raises(KeyError):
c.choose('Tair', mode='all')


def test_choose_mode_error():
c = xcollection.Collection()
with pytest.raises(ValueError):
c.choose('Tair', mode='foo')


@pytest.mark.parametrize('data_vars', ['Tair', ['air']])
def test_choose_any(data_vars):
c = xcollection.Collection({'foo': ds, 'bar': dsa})
d = c.choose(data_vars, mode='any')
assert len(d) == 1
40 changes: 40 additions & 0 deletions xcollection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import MutableMapping

import pydantic
import toolz
import xarray as xr


Expand Down Expand Up @@ -79,3 +80,42 @@ def values(self) -> typing.Iterable[xr.Dataset]:

def items(self) -> typing.Iterable[typing.Tuple[str, xr.Dataset]]:
return self.datasets.items()

def choose(
self, data_vars: typing.Union[str, typing.List[str]], *, mode: str = 'any'
) -> 'Collection':
"""Return a collection with datasets containing all or any of the specified data variables.
Parameters
----------
data_vars : str or list of str
The data variables to select on.
mode : str, optional
The selection mode. Must be one of 'all' or 'any'. Defaults to 'any'.
Returns
-------
Collection
A new collection containing only the selected datasets.
"""

_VALID_MODES = ['all', 'any']
if mode not in _VALID_MODES:
raise ValueError(f'Invalid mode: {mode}. Accepted modes are {_VALID_MODES}')

if isinstance(data_vars, str):
data_vars = [data_vars]

def _select_vars(dset):
try:
return dset[data_vars]
except KeyError:
if mode == 'all':
raise KeyError(f'No data variables: `{data_vars}` found in dataset: {dset!r}')

if mode == 'all':
result = toolz.valmap(_select_vars, self.datasets)
elif mode == 'any':
result = toolz.valfilter(_select_vars, self.datasets)

return type(self)(datasets=result)

0 comments on commit b56be0d

Please sign in to comment.