diff --git a/hdc/algo/accessors.py b/hdc/algo/accessors.py index b063026..6b797b6 100644 --- a/hdc/algo/accessors.py +++ b/hdc/algo/accessors.py @@ -1,18 +1,17 @@ """Xarray Accesor classes.""" -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, Optional, Sequence, Union from warnings import warn from dask import is_dask_collection import dask.array as da from dask.base import tokenize import numpy as np -import pandas as pd import xarray from . import ops from .dekad import Dekad -from .utils import to_linspace +from .utils import get_calibration_indices, to_linspace __all__ = [ "Anomalies", @@ -497,33 +496,31 @@ def spi( tix = self._obj.get_index("time") - calstart_ix = 0 - if calibration_start is not None: - calstart = pd.Timestamp(calibration_start) - if calstart > tix[-1]: - raise ValueError( - "Calibration start cannot be greater than last timestamp!" - ) - (calstart_ix,) = tix.get_indexer([calstart], method="bfill") + if calibration_start is None: + calibration_start = tix[0] - calstop_ix = tix.size - if calibration_stop is not None: - calstop = pd.Timestamp(calibration_stop) - if calstop < tix[0]: - raise ValueError( - "Calibration stop cannot be smaller than first timestamp!" - ) - (calstop_ix,) = tix.get_indexer([calstop], method="ffill") + 1 + if calibration_stop is None: + calibration_stop = tix[-1] - if calstart_ix >= calstop_ix: - raise ValueError("calibration_start < calibration_stop!") + if calibration_start > tix[-1:]: + raise ValueError("Calibration start cannot be greater than last timestamp!") - if abs(calstop_ix - calstart_ix) <= 1: - raise ValueError( - "Timeseries too short for calculating SPI. Please adjust calibration period!" - ) + if calibration_stop < tix[:1]: + raise ValueError("Calibration stop cannot be smaller than first timestamp!") if groups is None: + calstart_ix, calstop_ix = get_calibration_indices( + tix, calibration_start, calibration_stop + ) + + if calstart_ix >= calstop_ix: + raise ValueError("calibration_start < calibration_stop!") + + if abs(calstop_ix - calstart_ix) <= 1: + raise ValueError( + "Timeseries too short for calculating SPI. Please adjust calibration period!" + ) + res = xarray.apply_ufunc( gammastd_yxt, self._obj, @@ -540,22 +537,37 @@ def spi( ) else: + + groups, keys = to_linspace(np.array(groups, dtype="str")) + if len(groups) != len(self._obj.time): raise ValueError("Need array of groups same length as time dimension!") - groups, keys = to_linspace(np.array(groups, dtype="str")) groups = groups.astype("int16") num_groups = len(keys) + cal_indices = get_calibration_indices( + tix, calibration_start, calibration_stop, groups, num_groups + ) + # assert for mypy + assert isinstance(cal_indices, np.ndarray) + + if np.any(cal_indices[:, 0] >= cal_indices[:, 1]): + raise ValueError("calibration_start < calibration_stop!") + + if np.any(np.diff(cal_indices, axis=1) <= 1): + raise ValueError( + "Timeseries too short for calculating SPI. Please adjust calibration period!" + ) + res = xarray.apply_ufunc( gammastd_grp, self._obj, groups, num_groups, nodata, - calstart_ix, - calstop_ix, - input_core_dims=[["time"], ["grps"], [], [], [], []], + cal_indices, + input_core_dims=[["time"], ["grps"], [], [], ["start", "stop"]], output_core_dims=[["time"]], keep_attrs=True, dask="parallelized", @@ -564,8 +576,8 @@ def spi( res.attrs.update( { - "spi_calibration_start": str(tix[calstart_ix].date()), - "spi_calibration_stop": str(tix[calstop_ix - 1].date()), + "spi_calibration_start": str(tix[tix >= calibration_start][0]), + "spi_calibration_stop": str(tix[tix <= calibration_stop][-1]), } ) diff --git a/hdc/algo/utils.py b/hdc/algo/utils.py index 8f4e34d..df8c992 100644 --- a/hdc/algo/utils.py +++ b/hdc/algo/utils.py @@ -3,10 +3,11 @@ from typing import Iterable, List, Optional, Tuple, Union import numpy as np +from numpy.typing import NDArray import pandas as pd -def to_linspace(x) -> Tuple[np.ndarray, List[int]]: +def to_linspace(x) -> Tuple[NDArray[(np.int16,)], List[int]]: """Map input array to linear space. Returns array with linear index (0 - n-1) and list of @@ -41,24 +42,19 @@ def get_calibration_indices( the second column is the stop index. Parameters: - time (pd.DatetimeIndex): The time index. - start (str): The start time of the calibration range. - stop (str): The stop time of the calibration range. - groups (Optional[Iterable[Union[int, float, str]]]): Optional groups to consider for calibration. - num_groups (Optional[int]): Optional number of groups to consider for calibration. - - Returns: - Union[Tuple[int, int], np.ndarray]: The calibration indices. If groups is None, returns a tuple of two indices. - If groups is not None, returns a numpy array of shape (num_groups, 2) containing the indices for each group. - + time: The time index. + start: The start time of the calibration range. + stop: The stop time of the calibration range. + groups: Optional groups to consider for calibration. + num_groups: Optional number of groups to consider for calibration. """ - def _get_ix(x: np.ndarray[np.datetime64], v: str, side: str): - return x.searchsorted(np.datetime64(v), side) + def _get_ix(x: NDArray[(np.datetime64,)], v: str, side: str): + return x.searchsorted(np.datetime64(v), side) # type: ignore if groups is not None: if num_groups is None: - num_groups = len(np.unique(groups)) + num_groups = len(np.unique(np.array(groups))) return np.array( [ [ diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 0b82810..e0cb141 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -220,18 +220,18 @@ def test_algo_spi_transp(darr, res_spi): def test_algo_spi_attrs_default(darr): _res = darr.hdc.algo.spi() - assert _res.attrs["spi_calibration_start"] == str(darr.time.dt.date[0].values) - assert _res.attrs["spi_calibration_stop"] == str(darr.time.dt.date[-1].values) + assert _res.attrs["spi_calibration_start"] == str(darr.time.to_index()[0]) + assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1]) def test_algo_spi_attrs_start(darr): _res = darr.hdc.algo.spi(calibration_start="2000-01-02") - assert _res.attrs["spi_calibration_start"] == "2000-01-11" + assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00" def test_algo_spi_attrs_stop(darr): _res = darr.hdc.algo.spi(calibration_stop="2000-02-09") - assert _res.attrs["spi_calibration_stop"] == "2000-01-31" + assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00" def test_algo_spi_decoupled_1(darr, res_spi): @@ -242,8 +242,8 @@ def test_algo_spi_decoupled_1(darr, res_spi): assert isinstance(_res, xr.DataArray) np.testing.assert_array_equal(_res, res_spi) - assert _res.attrs["spi_calibration_start"] == "2000-01-01" - assert _res.attrs["spi_calibration_stop"] == "2000-02-10" + assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00" + assert _res.attrs["spi_calibration_stop"] == "2000-02-10 00:00:00" def test_algo_spi_decoupled_2(darr): @@ -261,8 +261,8 @@ def test_algo_spi_decoupled_2(darr): assert isinstance(_res, xr.DataArray) np.testing.assert_array_equal(_res, res_spi) - assert _res.attrs["spi_calibration_start"] == "2000-01-01" - assert _res.attrs["spi_calibration_stop"] == "2000-01-31" + assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00" + assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00" def test_algo_spi_decoupled_3(darr): @@ -278,8 +278,8 @@ def test_algo_spi_decoupled_3(darr): assert isinstance(_res, xr.DataArray) np.testing.assert_array_equal(_res, res_spi) - assert _res.attrs["spi_calibration_start"] == "2000-01-11" - assert _res.attrs["spi_calibration_stop"] == "2000-02-10" + assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00" + assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1]) def test_algo_spi_nodata(darr):