-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from dougiesquire/skip_nan_in_select_time_period
Allowing `time_utils.select_time_period` to work on arrays containing NaNs
- Loading branch information
Showing
6 changed files
with
193 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import pytest | ||
|
||
import numpy as np | ||
import xarray as xr | ||
|
||
import dask | ||
import dask.array as dsa | ||
|
||
|
||
def pytest_configure(): | ||
pytest.TIME_DIM = "time" | ||
pytest.INIT_DIM = "init" | ||
pytest.LEAD_DIM = "lead" | ||
|
||
|
||
def empty_dask_array(shape, dtype=float, chunks=None): | ||
"""A dask array that errors if you try to compute it | ||
Stolen from https://github.com/xgcm/xhistogram/blob/master/xhistogram/test/fixtures.py | ||
""" | ||
|
||
def raise_if_computed(): | ||
raise ValueError("Triggered forbidden computation on dask array") | ||
|
||
a = dsa.from_delayed(dask.delayed(raise_if_computed)(), shape, dtype) | ||
if chunks is not None: | ||
a = a.rechunk(chunks) | ||
return a | ||
|
||
|
||
@pytest.fixture() | ||
def example_da_timeseries(request): | ||
"""An example timeseries DataArray""" | ||
time = xr.cftime_range(start="2000-01-01", end="2001-12-31", freq="D") | ||
if request.param == "dask": | ||
data = empty_dask_array((len(time),)) | ||
else: | ||
data = np.array([t.toordinal() for t in time]) | ||
data -= data[0] | ||
return xr.DataArray(data, coords=[time], dims=[pytest.TIME_DIM]) | ||
|
||
|
||
@pytest.fixture() | ||
def example_da_forecast(request): | ||
"""An example forecast DataArray""" | ||
N_INIT = 24 # Keep at least 6 | ||
N_LEAD = 12 # Keep at least 6 | ||
START = "2000-01-01" # DO NOT CHANGE | ||
init = xr.cftime_range(start=START, periods=N_INIT, freq="MS") | ||
lead = range(N_LEAD) | ||
time = [init.shift(i, freq="MS")[:N_LEAD] for i in range(len(init))] | ||
if request.param == "dask": | ||
data = empty_dask_array( | ||
( | ||
len(init), | ||
len(lead), | ||
) | ||
) | ||
else: | ||
data = np.random.random( | ||
( | ||
len(init), | ||
len(lead), | ||
) | ||
) | ||
ds = xr.DataArray( | ||
data, coords=[init, lead], dims=[pytest.INIT_DIM, pytest.LEAD_DIM] | ||
) | ||
return ds.assign_coords( | ||
{pytest.TIME_DIM: ([pytest.INIT_DIM, pytest.LEAD_DIM], time)} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import pytest | ||
|
||
import numpy as np | ||
from xarray.coding.times import cftime_to_nptime | ||
|
||
from unseen.time_utils import ( | ||
select_time_period, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("example_da_forecast", ["numpy"], indirect=True) | ||
@pytest.mark.parametrize("add_nans", [False, True]) | ||
@pytest.mark.parametrize("data_object", ["DataArray", "Dataset"]) | ||
def test_select_time_period(example_da_forecast, add_nans, data_object): | ||
"""Test values returned by select_time_period""" | ||
PERIOD = ["2000-06-01", "2001-06-01"] | ||
|
||
data = example_da_forecast | ||
if data_object == "Datatset": | ||
data = data.to_dataset(name="var") | ||
|
||
if add_nans: | ||
time_nans = data[pytest.TIME_DIM].where(data[pytest.LEAD_DIM] > 3) | ||
data = data.assign_coords({pytest.TIME_DIM: time_nans}) | ||
|
||
masked = select_time_period(data, PERIOD) | ||
|
||
min_time = cftime_to_nptime(data[pytest.TIME_DIM].where(masked.notnull()).min()) | ||
max_time = cftime_to_nptime(data[pytest.TIME_DIM].where(masked.notnull()).max()) | ||
|
||
assert min_time >= np.datetime64(PERIOD[0]) | ||
assert max_time <= np.datetime64(PERIOD[1]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters