Skip to content

Commit

Permalink
Merge pull request #12 from dougiesquire/skip_nan_in_select_time_period
Browse files Browse the repository at this point in the history
Allowing `time_utils.select_time_period` to work on arrays containing NaNs
  • Loading branch information
dougiesquire authored Dec 13, 2021
2 parents eb2eff7 + 7487854 commit 77be826
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 69 deletions.
8 changes: 7 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ $ conda env create -f environment.yml
$ conda activate unseen
```

Aside: it is handy to install your conda environment as an ipykernel. This makes a kernel with the `unseen` environment available from within Jupyter and you won't have to restart Jupyter to effectuate any changes/updates you make to the environment (simply restarting the kernel will do):

```
python -m ipykernel install --user --name unseen --display-name "Python (unseen)"
```

4. Install `unseen` using the editable flag (meaning any changes you make to the package will be reflected directly in your environment):

```
Expand All @@ -46,7 +52,7 @@ You can also run `pre-commit` manually at any point to format your code:
pre-commit run --all-files
```

6. Start making and committing your edits, including adding tests to `unseen/tests` to check that your contributions are doing what they're suppose to. To run the test suite:
6. Start making and committing your edits, including adding docstrings to your functions and tests to `unseen/tests` to check that your contributions are doing what they're supposed to. Please try to follow [numpydoc style](https://numpydoc.readthedocs.io/en/latest/format.html) for docstrings. To run the test suite:

```
pytest unseen
Expand Down
34 changes: 24 additions & 10 deletions unseen/array_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,35 @@ def stack_by_init_date(
):
"""Stack timeseries array in inital date / lead time format.
Args:
ds (xarray Dataset)
init_dates (list) : Initial dates in YYYY-MM-DD format
n_lead_steps (int) : Maximum lead time
time_dim (str) : The name of the time dimension on ds
init_dim (str) : The name of the initial date dimension on the output array
lead_dim (str) : The name of the lead time dimension on the output array
Note, only initial dates that fall within the time range of the input
Parameters
----------
ds : xarray DataArray or Dataset
Input array containing a time dimension
period : list
List of initial dates of the same object type as the times in
the time dimension of ds
n_lead_steps: int
Maximum number of lead time steps
time_name: str
Name of the time dimension in ds
init_name: str
Name of the initial date dimension to create in the output
lead_name: str
Name of the lead time dimension to create in the output
Returns
-------
stacked : xarray DataArray or Dataset
Array with data stacked by specified initial dates and lead steps
Notes
-----
Only initial dates that fall within the time range of the input
timeseries are retained. Thus, initial dates prior to the time range of
the input timeseries that include data at longer lead times are not
included in the output dataset. To include these data, prepend the input
timeseries with nans so that the initial dates in question are present
in the time dimension of the input timeseries.
"""
# Only keep init dates that fall within available times
times = ds[time_dim]
Expand Down
70 changes: 70 additions & 0 deletions unseen/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest

import numpy as np
import xarray as xr

import dask
import dask.array as dsa


def pytest_configure():
    """Expose shared dimension names on the pytest namespace for all tests."""
    for attr_name, dim_name in (
        ("TIME_DIM", "time"),
        ("INIT_DIM", "init"),
        ("LEAD_DIM", "lead"),
    ):
        setattr(pytest, attr_name, dim_name)


def empty_dask_array(shape, dtype=float, chunks=None):
    """Return a lazy dask array that raises if it is ever computed.

    Adapted from
    https://github.com/xgcm/xhistogram/blob/master/xhistogram/test/fixtures.py
    """

    def _forbidden():
        # Any attempt to materialise the data trips this error
        raise ValueError("Triggered forbidden computation on dask array")

    arr = dsa.from_delayed(dask.delayed(_forbidden)(), shape, dtype)
    return arr if chunks is None else arr.rechunk(chunks)


@pytest.fixture()
def example_da_timeseries(request):
    """Daily timeseries DataArray; backend ("dask" or numpy) set by request.param."""
    times = xr.cftime_range(start="2000-01-01", end="2001-12-31", freq="D")
    if request.param == "dask":
        values = empty_dask_array((len(times),))
    else:
        # Integer day offsets from the first date
        ordinals = np.array([t.toordinal() for t in times])
        values = ordinals - ordinals[0]
    return xr.DataArray(values, coords=[times], dims=[pytest.TIME_DIM])


@pytest.fixture()
def example_da_forecast(request):
    """Forecast (init x lead) DataArray; backend ("dask" or numpy) set by request.param."""
    n_init = 24  # Keep at least 6
    n_lead = 12  # Keep at least 6
    start = "2000-01-01"  # DO NOT CHANGE
    init_dates = xr.cftime_range(start=start, periods=n_init, freq="MS")
    lead_steps = range(n_lead)
    # Valid time for each (init, lead) pair: init date shifted by lead months
    valid_times = [init_dates.shift(i, freq="MS")[:n_lead] for i in range(n_init)]
    shape = (n_init, n_lead)
    if request.param == "dask":
        values = empty_dask_array(shape)
    else:
        values = np.random.random(shape)
    forecast = xr.DataArray(
        values,
        coords=[init_dates, lead_steps],
        dims=[pytest.INIT_DIM, pytest.LEAD_DIM],
    )
    return forecast.assign_coords(
        {pytest.TIME_DIM: ([pytest.INIT_DIM, pytest.LEAD_DIM], valid_times)}
    )
74 changes: 29 additions & 45 deletions unseen/tests/test_array_handling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import pytest

import dask
import dask.array as dsa

import numpy as np
import xarray as xr

Expand All @@ -13,42 +10,14 @@
)


TIME_DIM = "time"
INIT_DIM = "init"
LEAD_DIM = "lead"


def empty_dask_array(shape, dtype=float, chunks=None):
"""A dask array that errors if you try to compute it
Stolen from https://github.com/xgcm/xhistogram/blob/master/xhistogram/test/fixtures.py
"""

def raise_if_computed():
raise ValueError("Triggered forbidden computation on dask array")

a = dsa.from_delayed(dask.delayed(raise_if_computed)(), shape, dtype)
if chunks is not None:
a = a.rechunk(chunks)
return a


@pytest.fixture()
def example_da_timeseries(request):
"""An example timeseries DataArray"""
time = xr.cftime_range(start="2000-01-01", end="2001-12-31", freq="D")
if request.param == "dask":
data = empty_dask_array((len(time),))
else:
data = np.array([t.toordinal() for t in time])
data -= data[0]
return xr.DataArray(data, coords=[time], dims=[TIME_DIM])


@pytest.mark.parametrize("example_da_timeseries", ["numpy"], indirect=True)
@pytest.mark.parametrize("offset", [0, 10])
@pytest.mark.parametrize("stride", [1, 10, "irregular"])
@pytest.mark.parametrize("n_lead_steps", [1, 10])
def test_stack_by_init_date(example_da_timeseries, offset, stride, n_lead_steps):
@pytest.mark.parametrize("data_object", ["DataArray", "Dataset"])
def test_stack_by_init_date(
example_da_timeseries, offset, stride, n_lead_steps, data_object
):
"""Test values returned by stack_by_init_date"""

def _np_stack_by_init_date(data, indexes, n_lead_steps):
Expand All @@ -61,48 +30,63 @@ def _np_stack_by_init_date(data, indexes, n_lead_steps):
return ver

data = example_da_timeseries
if data_object == "Dataset":
data = data.to_dataset(name="var")

if stride == "irregular":
indexes = np.concatenate(
([offset], np.random.randint(1, 20, size=1000))
).cumsum()
indexes = indexes[indexes < data.sizes[TIME_DIM]]
indexes = indexes[indexes < data.sizes[pytest.TIME_DIM]]
else:
indexes = range(offset, data.sizes[TIME_DIM], stride)
indexes = range(offset, data.sizes[pytest.TIME_DIM], stride)

init_dates = data[TIME_DIM][indexes]
init_dates = data[pytest.TIME_DIM][indexes]
res = stack_by_init_date(
data, init_dates, n_lead_steps, init_dim=INIT_DIM, lead_dim=LEAD_DIM
data,
init_dates,
n_lead_steps,
init_dim=pytest.INIT_DIM,
lead_dim=pytest.LEAD_DIM,
)

if data_object == "Dataset":
res = res["var"]
data = data["var"]

ver = _np_stack_by_init_date(data, indexes, n_lead_steps)

# Check that values are correct
npt.assert_allclose(res, ver)

# Check that init dates are correct
npt.assert_allclose(
xr.CFTimeIndex(init_dates.values).asi8, res.get_index(INIT_DIM).asi8
xr.CFTimeIndex(init_dates.values).asi8, res.get_index(pytest.INIT_DIM).asi8
)

# Check that times at lead zero match the init dates
npt.assert_allclose(
xr.CFTimeIndex(init_dates.values).asi8,
xr.CFTimeIndex(res[TIME_DIM].isel({LEAD_DIM: 0}).values).asi8,
xr.CFTimeIndex(res[pytest.TIME_DIM].isel({pytest.LEAD_DIM: 0}).values).asi8,
)


@pytest.mark.parametrize("example_da_timeseries", ["dask"], indirect=True)
def test_stack_by_init_date_dask(example_da_timeseries):
"""Test values returned by stack_by_init_date
For now just checks that doesn't trigger compute, but may want to add tests for
chunking etc in the future
For now just checks that doesn't trigger compute, but may want to add tests
for chunking etc in the future
"""

data = example_da_timeseries
n_lead_steps = 10
init_dates = data[TIME_DIM][::10]
init_dates = data[pytest.TIME_DIM][::10]

stack_by_init_date(
data, init_dates, n_lead_steps, init_dim=INIT_DIM, lead_dim=LEAD_DIM
data,
init_dates,
n_lead_steps,
init_dim=pytest.INIT_DIM,
lead_dim=pytest.LEAD_DIM,
)
32 changes: 32 additions & 0 deletions unseen/tests/test_time_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

import numpy as np
from xarray.coding.times import cftime_to_nptime

from unseen.time_utils import (
select_time_period,
)


@pytest.mark.parametrize("example_da_forecast", ["numpy"], indirect=True)
@pytest.mark.parametrize("add_nans", [False, True])
@pytest.mark.parametrize("data_object", ["DataArray", "Dataset"])
def test_select_time_period(example_da_forecast, add_nans, data_object):
"""Test values returned by select_time_period"""
PERIOD = ["2000-06-01", "2001-06-01"]

data = example_da_forecast
if data_object == "Datatset":
data = data.to_dataset(name="var")

if add_nans:
time_nans = data[pytest.TIME_DIM].where(data[pytest.LEAD_DIM] > 3)
data = data.assign_coords({pytest.TIME_DIM: time_nans})

masked = select_time_period(data, PERIOD)

min_time = cftime_to_nptime(data[pytest.TIME_DIM].where(masked.notnull()).min())
max_time = cftime_to_nptime(data[pytest.TIME_DIM].where(masked.notnull()).max())

assert min_time >= np.datetime64(PERIOD[0])
assert max_time <= np.datetime64(PERIOD[1])
44 changes: 31 additions & 13 deletions unseen/time_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,32 +217,50 @@ def get_clim(ds, dims, time_period=None, groupby_init_month=False):
return clim


def select_time_period(ds, period):
def select_time_period(ds, period, time_name="time"):
"""Select a period of time.
Args:
ds (xarray DataSet or DataArray)
period (list) : Start and stop dates (in YYYY-MM-DD format)
Only works for cftime objects.
Parameters
----------
ds : xarray DataArray or Dataset
Input array containing time dimension or variable. The time
dimension of variable should be cftime but can contain nans.
period : list of str
Start and stop dates (in YYYY-MM-DD format)
time_name: str
Name of the time dimension, coordinate or variable
Returns
-------
masked : xarray DataArray
Array containing only times within provided period
"""

def _inbounds(t, bnds):
"""Check if time in bounds, allowing for nans"""
if t != t:
return False
else:
return (t >= bnds[0]) & (t <= bnds[1])

_vinbounds = np.vectorize(_inbounds)
_vinbounds.excluded.add(1)

check_date_format(period)
start, stop = period

if "time" in ds.dims:
selection = ds.sel({"time": slice(start, stop)})
elif "time" in ds.coords:
if time_name in ds.dims:
selection = ds.sel({time_name: slice(start, stop)})
elif time_name in ds.coords:
try:
calendar = ds["time"].calendar_type.lower()
calendar = ds[time_name].calendar_type.lower()
except AttributeError:
calendar = "standard"
time_bounds = xr.cftime_range(
start=start, end=stop, periods=2, freq=None, calendar=calendar
)
time_values = ds["time"].compute()
check_cftime(time_values)
mask = (time_values >= time_bounds[0]) & (time_values <= time_bounds[1])
time_values = ds[time_name].values
mask = _vinbounds(time_values, time_bounds)
selection = ds.where(mask)
else:
raise ValueError("No time axis for masking")
Expand Down

0 comments on commit 77be826

Please sign in to comment.