Skip to content

Commit

Permalink
enable grouped indices in grouped spi
Browse files Browse the repository at this point in the history
  • Loading branch information
valpesendorfer committed Mar 19, 2024
1 parent 367cfeb commit 105a1d6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 38 deletions.
65 changes: 37 additions & 28 deletions hdc/algo/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from . import ops
from .dekad import Dekad
from .utils import to_linspace
from .utils import get_calibration_indices, to_linspace

__all__ = [
"Anomalies",
Expand Down Expand Up @@ -497,33 +497,31 @@ def spi(

tix = self._obj.get_index("time")

calstart_ix = 0
if calibration_start is not None:
calstart = pd.Timestamp(calibration_start)
if calstart > tix[-1]:
raise ValueError(
"Calibration start cannot be greater than last timestamp!"
)
(calstart_ix,) = tix.get_indexer([calstart], method="bfill")
if calibration_start is None:
calibration_start = tix[0]

calstop_ix = tix.size
if calibration_stop is not None:
calstop = pd.Timestamp(calibration_stop)
if calstop < tix[0]:
raise ValueError(
"Calibration stop cannot be smaller than first timestamp!"
)
(calstop_ix,) = tix.get_indexer([calstop], method="ffill") + 1
if calibration_stop is None:
calibration_stop = tix[-1]

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")
if calibration_start > tix[-1:]:
raise ValueError("Calibration start cannot be greater than last timestamp!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)
if calibration_stop < tix[:1]:
raise ValueError("Calibration stop cannot be smaller than first timestamp!")

if groups is None:
calstart_ix, calstop_ix = get_calibration_indices(
tix, calibration_start, calibration_stop
)

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_yxt,
self._obj,
Expand All @@ -547,15 +545,26 @@ def spi(
groups = groups.astype("int16")
num_groups = len(keys)

cal_indices = get_calibration_indices(
tix, calibration_start, calibration_stop, groups, num_groups
)

if any(cal_indices[:, 0] >= cal_indices[:, 1]):
raise ValueError("calibration_start < calibration_stop!")

if np.any(np.diff(cal_indices, axis=1) <= 1):
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_grp,
self._obj,
groups,
num_groups,
nodata,
calstart_ix,
calstop_ix,
input_core_dims=[["time"], ["grps"], [], [], [], []],
cal_indices,
input_core_dims=[["time"], ["grps"], [], [], ["start", "stop"]],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
Expand All @@ -564,8 +573,8 @@ def spi(

res.attrs.update(
{
"spi_calibration_start": str(tix[calstart_ix].date()),
"spi_calibration_stop": str(tix[calstop_ix - 1].date()),
"spi_calibration_start": str(tix[tix >= calibration_start][0]),
"spi_calibration_stop": str(tix[tix <= calibration_stop][-1]),
}
)

Expand Down
20 changes: 10 additions & 10 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,18 @@ def test_algo_spi_transp(darr, res_spi):

def test_algo_spi_attrs_default(darr):
_res = darr.hdc.algo.spi()
assert _res.attrs["spi_calibration_start"] == str(darr.time.dt.date[0].values)
assert _res.attrs["spi_calibration_stop"] == str(darr.time.dt.date[-1].values)
assert _res.attrs["spi_calibration_start"] == str(darr.time.to_index()[0])
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])


def test_algo_spi_attrs_start(darr):
_res = darr.hdc.algo.spi(calibration_start="2000-01-02")
assert _res.attrs["spi_calibration_start"] == "2000-01-11"
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"


def test_algo_spi_attrs_stop(darr):
_res = darr.hdc.algo.spi(calibration_stop="2000-02-09")
assert _res.attrs["spi_calibration_stop"] == "2000-01-31"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_1(darr, res_spi):
Expand All @@ -242,8 +242,8 @@ def test_algo_spi_decoupled_1(darr, res_spi):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10"
assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10 00:00:00"


def test_algo_spi_decoupled_2(darr):
Expand All @@ -261,8 +261,8 @@ def test_algo_spi_decoupled_2(darr):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31"
assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_3(darr):
Expand All @@ -278,8 +278,8 @@ def test_algo_spi_decoupled_3(darr):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-11"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10"
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])


def test_algo_spi_nodata(darr):
Expand Down

0 comments on commit 105a1d6

Please sign in to comment.