Skip to content

Commit

Permalink
reworking calibration range for SPI
Browse files Browse the repository at this point in the history
  • Loading branch information
valpesendorfer committed May 10, 2024
1 parent 7bbd550 commit c75efc8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 49 deletions.
34 changes: 17 additions & 17 deletions hdc/algo/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,16 @@ class PixelAlgorithms(AccessorBase):

def spi(
self,
calibration_start: Optional[str] = None,
calibration_stop: Optional[str] = None,
calibration_begin: Optional[str] = None,
calibration_end: Optional[str] = None,
nodata: Optional[Union[float, int]] = None,
groups: Optional[Iterable[Union[int, float, str]]] = None,
dtype="int16",
):
"""Calculate the SPI along the time dimension.
Calculates the Standardized Precipitation Index along the time dimension.
Optionally, a calibration start and / or stop date can be provided which
Optionally, a calibration begin and / or end date can be provided which
determine the part of the timeseries used to fit the gamma distribution.
`groups` can be supplied as list of group labels.
Expand All @@ -496,25 +496,25 @@ def spi(

tix = self._obj.get_index("time")

if calibration_start is None:
calibration_start = tix[0]
if calibration_begin is None:
calibration_begin = tix[0]

if calibration_stop is None:
calibration_stop = tix[-1]
if calibration_end is None:
calibration_end = tix[-1]

if calibration_start > tix[-1:]:
raise ValueError("Calibration start cannot be greater than last timestamp!")
if calibration_begin > tix[-1:]:
raise ValueError("Calibration begin cannot be greater than last timestamp!")

if calibration_stop < tix[:1]:
raise ValueError("Calibration stop cannot be smaller than first timestamp!")
if calibration_end < tix[:1]:
raise ValueError("Calibration end cannot be smaller than first timestamp!")

if groups is None:
calstart_ix, calstop_ix = get_calibration_indices(
tix, calibration_start, calibration_stop
tix, (calibration_begin, calibration_end)
)

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")
raise ValueError("calibration_begin < calibration_end!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
Expand Down Expand Up @@ -547,13 +547,13 @@ def spi(
num_groups = len(keys)

cal_indices = get_calibration_indices(
tix, calibration_start, calibration_stop, groups, num_groups
tix, (calibration_begin, calibration_end), groups, num_groups
)
# assert for mypy
assert isinstance(cal_indices, np.ndarray)

if np.any(cal_indices[:, 0] >= cal_indices[:, 1]):
raise ValueError("calibration_start < calibration_stop!")
raise ValueError("calibration_begin < calibration_end!")

if np.any(np.diff(cal_indices, axis=1) <= 1):
raise ValueError(
Expand All @@ -576,8 +576,8 @@ def spi(

res.attrs.update(
{
"spi_calibration_start": str(tix[tix >= calibration_start][0]),
"spi_calibration_stop": str(tix[tix <= calibration_stop][-1]),
"spi_calibration_begin": str(tix[tix >= calibration_begin][0]),
"spi_calibration_end": str(tix[tix <= calibration_end][-1]),
}
)

Expand Down
14 changes: 8 additions & 6 deletions hdc/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from numpy.typing import NDArray
import pandas as pd

DateType = Union[str, pd.Timestamp, np.datetime64]


def to_linspace(x) -> Tuple[NDArray[(np.int16,)], List[int]]:
"""Map input array to linear space.
Expand All @@ -27,8 +29,7 @@ def to_linspace(x) -> Tuple[NDArray[(np.int16,)], List[int]]:

def get_calibration_indices(
time: pd.DatetimeIndex,
start: Union[str, pd.Timestamp],
stop: Union[str, pd.Timestamp],
calibration_range: Tuple[DateType, DateType],
groups: Optional[Iterable[Union[int, float, str]]] = None,
num_groups: Optional[int] = None,
) -> Union[Tuple[int, int], np.ndarray]:
Expand All @@ -48,8 +49,9 @@ def get_calibration_indices(
groups: Optional groups to consider for calibration.
num_groups: Optional number of groups to consider for calibration.
"""
begin, end = calibration_range

def _get_ix(x: NDArray[(np.datetime64,)], v: str, side: str):
def _get_ix(x: NDArray[(np.datetime64,)], v: DateType, side: str):
return x.searchsorted(np.datetime64(v), side) # type: ignore

if groups is not None:
Expand All @@ -58,12 +60,12 @@ def _get_ix(x: NDArray[(np.datetime64,)], v: str, side: str):
return np.array(
[
[
_get_ix(time[groups == ix].values, start, "left"),
_get_ix(time[groups == ix].values, stop, "right"),
_get_ix(time[groups == ix].values, begin, "left"),
_get_ix(time[groups == ix].values, end, "right"),
]
for ix in range(num_groups)
],
dtype="int16",
)

return _get_ix(time.values, start, "left"), _get_ix(time.values, stop, "right")
return _get_ix(time.values, begin, "left"), _get_ix(time.values, end, "right")
42 changes: 21 additions & 21 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,30 +220,30 @@ def test_algo_spi_transp(darr, res_spi):

def test_algo_spi_attrs_default(darr):
_res = darr.hdc.algo.spi()
assert _res.attrs["spi_calibration_start"] == str(darr.time.to_index()[0])
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])
assert _res.attrs["spi_calibration_begin"] == str(darr.time.to_index()[0])
assert _res.attrs["spi_calibration_end"] == str(darr.time.to_index()[-1])


def test_algo_spi_attrs_start(darr):
_res = darr.hdc.algo.spi(calibration_start="2000-01-02")
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"
_res = darr.hdc.algo.spi(calibration_begin="2000-01-02")
assert _res.attrs["spi_calibration_begin"] == "2000-01-11 00:00:00"


def test_algo_spi_attrs_stop(darr):
_res = darr.hdc.algo.spi(calibration_stop="2000-02-09")
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"
_res = darr.hdc.algo.spi(calibration_end="2000-02-09")
assert _res.attrs["spi_calibration_end"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_1(darr, res_spi):
_res = darr.hdc.algo.spi(
calibration_start="2000-01-01", calibration_stop="2000-02-10"
calibration_begin="2000-01-01", calibration_end="2000-02-10"
)

assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10 00:00:00"
assert _res.attrs["spi_calibration_begin"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_end"] == "2000-02-10 00:00:00"


def test_algo_spi_decoupled_2(darr):
Expand All @@ -255,14 +255,14 @@ def test_algo_spi_decoupled_2(darr):
)

_res = darr.hdc.algo.spi(
calibration_start="2000-01-01", calibration_stop="2000-01-31"
calibration_begin="2000-01-01", calibration_end="2000-01-31"
)

assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"
assert _res.attrs["spi_calibration_begin"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_end"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_3(darr):
Expand All @@ -273,13 +273,13 @@ def test_algo_spi_decoupled_3(darr):
]
)

_res = darr.hdc.algo.spi(calibration_start="2000-01-11")
_res = darr.hdc.algo.spi(calibration_begin="2000-01-11")

assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])
assert _res.attrs["spi_calibration_begin"] == "2000-01-11 00:00:00"
assert _res.attrs["spi_calibration_end"] == str(darr.time.to_index()[-1])


def test_algo_spi_nodata(darr):
Expand All @@ -292,30 +292,30 @@ def test_algo_spi_nodata(darr):
def test_algo_spi_decoupled_err_1(darr):
with pytest.raises(ValueError):
_res = darr.hdc.algo.spi(
calibration_start="2000-03-01",
calibration_begin="2000-03-01",
)


def test_algo_spi_decoupled_err_2(darr):
with pytest.raises(ValueError):
_res = darr.hdc.algo.spi(
calibration_stop="1999-01-01",
calibration_end="1999-01-01",
)


def test_algo_spi_decoupled_err_3(darr):
with pytest.raises(ValueError):
_res = darr.hdc.algo.spi(
calibration_start="2000-01-01",
calibration_stop="2000-01-01",
calibration_begin="2000-01-01",
calibration_end="2000-01-01",
)


def test_algo_spi_decoupled_err_4(darr):
with pytest.raises(ValueError):
_res = darr.hdc.algo.spi(
calibration_start="2000-02-01",
calibration_stop="2000-01-01",
calibration_begin="2000-02-01",
calibration_end="2000-01-01",
)


Expand Down
13 changes: 8 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def test_to_linspace(input_array, expected_output):
def test_get_calibration_indices():
tix = pd.date_range("2010-01-01", "2010-12-31")

assert get_calibration_indices(tix, tix[0], tix[-1]) == (0, 365)
assert get_calibration_indices(tix, "2010-01-01", "2010-12-31") == (0, 365)
assert get_calibration_indices(tix, "2010-01-15", "2010-12-15") == (14, 349)
assert get_calibration_indices(tix, (tix[0], tix[-1])) == (0, 365)
assert get_calibration_indices(tix, ("2010-01-01", "2010-12-31")) == (0, 365)
assert get_calibration_indices(tix, ("2010-01-15", "2010-12-15")) == (14, 349)

# groups
res = np.array(
Expand All @@ -58,14 +58,17 @@ def test_get_calibration_indices():

np.testing.assert_array_equal(
get_calibration_indices(
tix, "2010-01-15", "2010-12-15", groups=tix.month.values - 1, num_groups=12
tix,
("2010-01-15", "2010-12-15"),
groups=tix.month.values - 1,
num_groups=12,
),
res,
)

np.testing.assert_array_equal(
get_calibration_indices(
tix, "2010-01-15", "2010-12-15", groups=tix.month.values - 1
tix, ("2010-01-15", "2010-12-15"), groups=tix.month.values - 1
),
res,
)

0 comments on commit c75efc8

Please sign in to comment.