Skip to content

Commit

Permalink
enable grouped indices in grouped spi
Browse files Browse the repository at this point in the history
  • Loading branch information
valpesendorfer committed Mar 19, 2024
1 parent 367cfeb commit 9e81abd
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 54 deletions.
72 changes: 42 additions & 30 deletions hdc/algo/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import dask.array as da
from dask.base import tokenize
import numpy as np
import pandas as pd
import xarray

from . import ops
from .dekad import Dekad
from .utils import to_linspace
from .utils import get_calibration_indices, to_linspace

__all__ = [
"Anomalies",
Expand Down Expand Up @@ -497,33 +496,31 @@ def spi(

tix = self._obj.get_index("time")

calstart_ix = 0
if calibration_start is not None:
calstart = pd.Timestamp(calibration_start)
if calstart > tix[-1]:
raise ValueError(
"Calibration start cannot be greater than last timestamp!"
)
(calstart_ix,) = tix.get_indexer([calstart], method="bfill")
if calibration_start is None:
calibration_start = tix[0]

calstop_ix = tix.size
if calibration_stop is not None:
calstop = pd.Timestamp(calibration_stop)
if calstop < tix[0]:
raise ValueError(
"Calibration stop cannot be smaller than first timestamp!"
)
(calstop_ix,) = tix.get_indexer([calstop], method="ffill") + 1
if calibration_stop is None:
calibration_stop = tix[-1]

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")
if calibration_start > tix[-1:]:
raise ValueError("Calibration start cannot be greater than last timestamp!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)
if calibration_stop < tix[:1]:
raise ValueError("Calibration stop cannot be smaller than first timestamp!")

if groups is None:
calstart_ix, calstop_ix = get_calibration_indices(
tix, calibration_start, calibration_stop
)

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_yxt,
self._obj,
Expand All @@ -540,22 +537,37 @@ def spi(
)

else:

groups, keys = to_linspace(np.array(groups, dtype="str"))

if len(groups) != len(self._obj.time):
raise ValueError("Need array of groups same length as time dimension!")

groups, keys = to_linspace(np.array(groups, dtype="str"))
groups = groups.astype("int16")
num_groups = len(keys)

cal_indices = get_calibration_indices(
tix, calibration_start, calibration_stop, groups, num_groups
)
# assert for mypy
assert isinstance(cal_indices, np.ndarray)

if np.any(cal_indices[:, 0] >= cal_indices[:, 1]):
raise ValueError("calibration_start < calibration_stop!")

if np.any(np.diff(cal_indices, axis=1) <= 1):
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_grp,
self._obj,
groups,
num_groups,
nodata,
calstart_ix,
calstop_ix,
input_core_dims=[["time"], ["grps"], [], [], [], []],
cal_indices,
input_core_dims=[["time"], ["grps"], [], [], ["start", "stop"]],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
Expand All @@ -564,8 +576,8 @@ def spi(

res.attrs.update(
{
"spi_calibration_start": str(tix[calstart_ix].date()),
"spi_calibration_stop": str(tix[calstop_ix - 1].date()),
"spi_calibration_start": str(tix[tix >= calibration_start][0]),
"spi_calibration_stop": str(tix[tix <= calibration_stop][-1]),
}
)

Expand Down
24 changes: 10 additions & 14 deletions hdc/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray
import pandas as pd


def to_linspace(x) -> Tuple[np.ndarray, List[int]]:
def to_linspace(x) -> Tuple[NDArray[(np.int16,)], List[int]]:
"""Map input array to linear space.
Returns array with linear index (0 - n-1) and list of
Expand Down Expand Up @@ -41,24 +42,19 @@ def get_calibration_indices(
the second column is the stop index.
Parameters:
time (pd.DatetimeIndex): The time index.
start (str): The start time of the calibration range.
stop (str): The stop time of the calibration range.
groups (Optional[Iterable[Union[int, float, str]]]): Optional groups to consider for calibration.
num_groups (Optional[int]): Optional number of groups to consider for calibration.
Returns:
Union[Tuple[int, int], np.ndarray]: The calibration indices. If groups is None, returns a tuple of two indices.
If groups is not None, returns a numpy array of shape (num_groups, 2) containing the indices for each group.
time: The time index.
start: The start time of the calibration range.
stop: The stop time of the calibration range.
groups: Optional groups to consider for calibration.
num_groups: Optional number of groups to consider for calibration.
"""

def _get_ix(x: np.ndarray[np.datetime64], v: str, side: str):
return x.searchsorted(np.datetime64(v), side)
def _get_ix(x: NDArray[(np.datetime64,)], v: str, side: str):
return x.searchsorted(np.datetime64(v), side) # type: ignore

if groups is not None:
if num_groups is None:
num_groups = len(np.unique(groups))
num_groups = len(np.unique(np.array(groups)))
return np.array(
[
[
Expand Down
20 changes: 10 additions & 10 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,18 @@ def test_algo_spi_transp(darr, res_spi):

def test_algo_spi_attrs_default(darr):
_res = darr.hdc.algo.spi()
assert _res.attrs["spi_calibration_start"] == str(darr.time.dt.date[0].values)
assert _res.attrs["spi_calibration_stop"] == str(darr.time.dt.date[-1].values)
assert _res.attrs["spi_calibration_start"] == str(darr.time.to_index()[0])
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])


def test_algo_spi_attrs_start(darr):
_res = darr.hdc.algo.spi(calibration_start="2000-01-02")
assert _res.attrs["spi_calibration_start"] == "2000-01-11"
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"


def test_algo_spi_attrs_stop(darr):
_res = darr.hdc.algo.spi(calibration_stop="2000-02-09")
assert _res.attrs["spi_calibration_stop"] == "2000-01-31"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_1(darr, res_spi):
Expand All @@ -242,8 +242,8 @@ def test_algo_spi_decoupled_1(darr, res_spi):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10"
assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10 00:00:00"


def test_algo_spi_decoupled_2(darr):
Expand All @@ -261,8 +261,8 @@ def test_algo_spi_decoupled_2(darr):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31"
assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_3(darr):
Expand All @@ -278,8 +278,8 @@ def test_algo_spi_decoupled_3(darr):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-11"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10"
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])


def test_algo_spi_nodata(darr):
Expand Down

0 comments on commit 9e81abd

Please sign in to comment.