Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix & improve grouped spi #50

Merged
merged 9 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hdc/algo/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version only in this file."""

__version__ = "0.4.0"
__version__ = "0.5.0"
88 changes: 50 additions & 38 deletions hdc/algo/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import dask.array as da
from dask.base import tokenize
import numpy as np
import pandas as pd
import xarray

from . import ops
from .dekad import Dekad
from .utils import get_calibration_indices, to_linspace

__all__ = [
"Anomalies",
Expand Down Expand Up @@ -466,7 +466,7 @@ def spi(
calibration_start: Optional[str] = None,
calibration_stop: Optional[str] = None,
nodata: Optional[Union[float, int]] = None,
groups: Optional[Iterable[int]] = None,
groups: Optional[Iterable[Union[int, float, str]]] = None,
dtype="int16",
):
"""Calculate the SPI along the time dimension.
Expand All @@ -475,8 +475,7 @@ def spi(
Optionally, a calibration start and / or stop date can be provided which
determine the part of the timeseries used to fit the gamma distribution.

`groups` can be supplied as list of group labels (attention, they are required
to be in format {0..n-1} where n is the number of unique groups).
`groups` can be supplied as list of group labels.
If `groups` is supplied, the SPI will be computed for each individual group.
This is intended to be used when SPI should be calculated for specific timesteps.
"""
Expand All @@ -497,33 +496,31 @@ def spi(

tix = self._obj.get_index("time")

calstart_ix = 0
if calibration_start is not None:
calstart = pd.Timestamp(calibration_start)
if calstart > tix[-1]:
raise ValueError(
"Calibration start cannot be greater than last timestamp!"
)
(calstart_ix,) = tix.get_indexer([calstart], method="bfill")
if calibration_start is None:
calibration_start = tix[0]

calstop_ix = tix.size
if calibration_stop is not None:
calstop = pd.Timestamp(calibration_stop)
if calstop < tix[0]:
raise ValueError(
"Calibration stop cannot be smaller than first timestamp!"
)
(calstop_ix,) = tix.get_indexer([calstop], method="ffill") + 1
if calibration_stop is None:
calibration_stop = tix[-1]

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")
if calibration_start > tix[-1:]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this comparing str > [str]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comparing str to pd.Timestamp

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why tix[-1:] and not tix[-1]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs to remain a pd.DateTimeIndex with one element to work ... maybe bit hacky & counterintuitive

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could explicitly convert the string to a date object (pd.Timestamp or whatever) and then compare the elements if that's clearer

raise ValueError("Calibration start cannot be greater than last timestamp!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)
if calibration_stop < tix[:1]:
raise ValueError("Calibration stop cannot be smaller than first timestamp!")

if groups is None:
calstart_ix, calstop_ix = get_calibration_indices(
tix, calibration_start, calibration_stop
)

if calstart_ix >= calstop_ix:
raise ValueError("calibration_start < calibration_stop!")

if abs(calstop_ix - calstart_ix) <= 1:
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_yxt,
self._obj,
Expand All @@ -540,22 +537,37 @@ def spi(
)

else:
groups = np.array(groups) if not isinstance(groups, np.ndarray) else groups
num_groups = np.unique(groups).size

if not groups.dtype.name == "int16":
warn("Casting groups to int16!")
groups = groups.astype("int16")
groups, keys = to_linspace(np.array(groups, dtype="str"))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Kirill888 the idea is that the user passes in a list / array / tuple / ... containing group labels which puts the time dimension into bins. The downstream function requires these groups to be in a linear space, so from 0 to len(groups)-1 which was the expected input previously.

To take this burden off the user, this allows any labels and converts them to a str array, which is then converted to linear space. Maybe not most efficient, but for these applications totally fine - does that make sense for you?

For a practical example, we'll be using this to calculate SPIs for the full timeseries, grouping the time dimension into dekads per year. The user can then pass simply xx.time.dekad.yidx in without having to make sure the groups array is 0 indexed.


if len(groups) != len(self._obj.time):
raise ValueError("Need array of groups same length as time dimension!")

groups = groups.astype("int16")
num_groups = len(keys)

cal_indices = get_calibration_indices(
tix, calibration_start, calibration_stop, groups, num_groups
)
valpesendorfer marked this conversation as resolved.
Show resolved Hide resolved
# assert for mypy
assert isinstance(cal_indices, np.ndarray)

if np.any(cal_indices[:, 0] >= cal_indices[:, 1]):
raise ValueError("calibration_start < calibration_stop!")

if np.any(np.diff(cal_indices, axis=1) <= 1):
raise ValueError(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_grp,
self._obj,
groups,
num_groups,
nodata,
calstart_ix,
calstop_ix,
input_core_dims=[["time"], ["grps"], [], [], [], []],
cal_indices,
input_core_dims=[["time"], ["grps"], [], [], ["start", "stop"]],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
Expand All @@ -564,8 +576,8 @@ def spi(

res.attrs.update(
{
"spi_calibration_start": str(tix[calstart_ix].date()),
"spi_calibration_stop": str(tix[calstop_ix - 1].date()),
"spi_calibration_start": str(tix[tix >= calibration_start][0]),
"spi_calibration_stop": str(tix[tix <= calibration_stop][-1]),
}
)

Expand Down Expand Up @@ -806,7 +818,7 @@ def mean(

# set null values to nodata value
xx = xx.where(xx.notnull(), xx.nodata)

attrs = xx.attrs
num_zones = len(zone_ids)
dims = (xx.dims[0], dim_name, "stat")
coords = {
Expand Down Expand Up @@ -849,7 +861,7 @@ def mean(
)

return xarray.DataArray(
data=data, dims=dims, coords=coords, attrs={}, name=name
data=data, dims=dims, coords=coords, attrs=attrs, name=name
)


Expand Down
15 changes: 11 additions & 4 deletions hdc/algo/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,28 @@ def gammastd_yxt(
@lazycompile(
guvectorize(
[
"(int16[:], int16[:], float64, float64, float64, float64, int16[:])",
"(float32[:], int16[:], float64, float64, float64, float64, int16[:])",
"(int16[:], int16[:], float64, float64, int16[:, :], int16[:])",
"(float32[:], int16[:], float64, float64, int16[:, :], int16[:])",
],
"(n),(m),(),(),(),() -> (n)",
"(n),(m),(),(),(o, p) -> (n)",
)
)
def gammastd_grp(xx, groups, num_groups, nodata, cal_start, cal_stop, yy):
def gammastd_grp(xx, groups, num_groups, nodata, cal_indices, yy):
"""Calculate the gammastd for specific groups.

This calculates gammastd across xx for indivual groups
defined in `groups`. These need to be in ascending order from
0 to num_groups - 1.

`cal_indices` is an array of shape (num_groups, 2) where each row
contains the start and end index for the calibration period for each group.
"""
for grp in range(num_groups):
grp_ix = groups == grp

cal_start = cal_indices[grp, 0]
cal_stop = cal_indices[grp, 1]

pix = xx[grp_ix]
if (pix != nodata).sum() == 0:
yy[grp_ix] = nodata
Expand Down
50 changes: 48 additions & 2 deletions hdc/algo/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""hcd-algo utility functions."""

from typing import List, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray
import pandas as pd


def to_linspace(x) -> Tuple[np.ndarray, List[int]]:
def to_linspace(x) -> Tuple[NDArray[(np.int16,)], List[int]]:
"""Map input array to linear space.

Returns array with linear index (0 - n-1) and list of
Expand All @@ -21,3 +23,47 @@ def to_linspace(x) -> Tuple[np.ndarray, List[int]]:
new_pix = np.where(mask, values[idx], 0)

return new_pix, list(keys)


def get_calibration_indices(
    time: pd.DatetimeIndex,
    start: Union[str, pd.Timestamp],
    stop: Union[str, pd.Timestamp],
    groups: Optional[Iterable[Union[int, float, str]]] = None,
    num_groups: Optional[int] = None,
) -> Union[Tuple[int, int], np.ndarray]:
    """
    Get the calibration indices for a given time range.

    This function returns indices for a calibration period (e.g. used for SPI)
    given an index of timestamps and a start & stop date.
    If groups are provided, the indices are returned per group, as an
    array of shape (num_groups, 2) where the first column is the start index and
    the second column is the stop index.

    Parameters:
        time: The time index.
        start: The start time of the calibration range (inclusive).
        stop: The stop time of the calibration range (inclusive).
        groups: Optional group label per timestamp; labels are expected to be
            integers in the range [0, num_groups).
        num_groups: Optional number of groups; derived from `groups` if omitted.

    Returns:
        ``(start_ix, stop_ix)`` for the plain case, or an int16 array of shape
        ``(num_groups, 2)`` of per-group ``[start_ix, stop_ix]`` rows.
    """

    def _get_ix(x: NDArray[np.datetime64], v: Union[str, pd.Timestamp], side: str) -> int:
        # searchsorted makes both bounds inclusive:
        # side="left"  -> first index with time >= start
        # side="right" -> one past the last index with time <= stop
        return x.searchsorted(np.datetime64(v), side)  # type: ignore

    if groups is not None:
        # Accept any iterable (list, tuple, Series, ...), not only ndarrays;
        # a plain list would otherwise evaluate `groups == ix` to a scalar.
        grp_arr = np.asarray(groups)
        if num_groups is None:
            num_groups = len(np.unique(grp_arr))

        indices = []
        for ix in range(num_groups):
            # Restrict the search to the timestamps belonging to this group.
            grp_time = time[grp_arr == ix].values
            indices.append(
                [
                    _get_ix(grp_time, start, "left"),
                    _get_ix(grp_time, stop, "right"),
                ]
            )
        return np.array(indices, dtype="int16")

    return _get_ix(time.values, start, "left"), _get_ix(time.values, stop, "right")
31 changes: 21 additions & 10 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ def test_algo_spi_grouped(darr, res_spi):
np.testing.assert_array_equal(_res, res_spi)


def test_algo_spi_grouped_2(darr, res_spi):
    # SPI computed with a single group carrying an arbitrary string label
    # must equal the ungrouped reference result.
    groups = ["D1"] * 5
    result = darr.astype("float32").hdc.algo.spi(groups=groups)
    assert isinstance(result, xr.DataArray)
    np.testing.assert_array_equal(result, res_spi)


def test_algo_spi_transp(darr, res_spi):
_darr = darr.transpose(..., "time")
_res = _darr.hdc.algo.spi()
Expand All @@ -214,18 +220,18 @@ def test_algo_spi_transp(darr, res_spi):

def test_algo_spi_attrs_default(darr):
_res = darr.hdc.algo.spi()
assert _res.attrs["spi_calibration_start"] == str(darr.time.dt.date[0].values)
assert _res.attrs["spi_calibration_stop"] == str(darr.time.dt.date[-1].values)
assert _res.attrs["spi_calibration_start"] == str(darr.time.to_index()[0])
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])


def test_algo_spi_attrs_start(darr):
_res = darr.hdc.algo.spi(calibration_start="2000-01-02")
assert _res.attrs["spi_calibration_start"] == "2000-01-11"
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"


def test_algo_spi_attrs_stop(darr):
_res = darr.hdc.algo.spi(calibration_stop="2000-02-09")
assert _res.attrs["spi_calibration_stop"] == "2000-01-31"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_1(darr, res_spi):
Expand All @@ -236,8 +242,8 @@ def test_algo_spi_decoupled_1(darr, res_spi):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10"
assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10 00:00:00"


def test_algo_spi_decoupled_2(darr):
Expand All @@ -255,8 +261,8 @@ def test_algo_spi_decoupled_2(darr):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-01"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31"
assert _res.attrs["spi_calibration_start"] == "2000-01-01 00:00:00"
assert _res.attrs["spi_calibration_stop"] == "2000-01-31 00:00:00"


def test_algo_spi_decoupled_3(darr):
Expand All @@ -272,8 +278,8 @@ def test_algo_spi_decoupled_3(darr):
assert isinstance(_res, xr.DataArray)
np.testing.assert_array_equal(_res, res_spi)

assert _res.attrs["spi_calibration_start"] == "2000-01-11"
assert _res.attrs["spi_calibration_stop"] == "2000-02-10"
assert _res.attrs["spi_calibration_start"] == "2000-01-11 00:00:00"
assert _res.attrs["spi_calibration_stop"] == str(darr.time.to_index()[-1])


def test_algo_spi_nodata(darr):
Expand Down Expand Up @@ -965,6 +971,7 @@ def test_zonal_mean(darr, zones):
assert list(x.coords["stat"].values) == ["mean", "valid"]
np.testing.assert_almost_equal(x, res)
assert x.dtype == res.dtype
assert x.nodata == darr.nodata


def test_zonal_mean_nodata(darr, zones):
Expand All @@ -985,6 +992,7 @@ def test_zonal_mean_nodata(darr, zones):
x = darr.hdc.zonal.mean(zones, z_ids)
np.testing.assert_almost_equal(x, res)
assert x.dtype == res.dtype
assert x.nodata == darr.nodata


def test_zonal_mean_nodata_nan(darr, zones):
Expand All @@ -994,6 +1002,7 @@ def test_zonal_mean_nodata_nan(darr, zones):
x = darr.hdc.zonal.mean(zones, z_ids)
assert np.isnan(x.data[[0, -1], :, 0]).all()
assert np.all(x.data[[0, -1], :, 1] == 0)
assert x.nodata == darr.nodata


def test_zonal_mean_nodata_nan_float(darr, zones):
Expand All @@ -1015,6 +1024,7 @@ def test_zonal_mean_nodata_nan_float(darr, zones):
x = darr.hdc.zonal.mean(zones, z_ids, dtype="float64")
np.testing.assert_almost_equal(x, res)
assert x.dtype == res.dtype
assert x.nodata == darr.nodata


def test_zonal_zone_nodata_nan(darr, zones):
Expand All @@ -1034,6 +1044,7 @@ def test_zonal_zone_nodata_nan(darr, zones):
x = darr.hdc.zonal.mean(zones, z_ids, dim_name="foo", dtype="float64")
np.testing.assert_almost_equal(x, res)
assert x.dtype == res.dtype
assert x.nodata == darr.nodata


def test_zonal_dimname(darr, zones):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def test_gammastd_selfit_2(ts):
def test_gammastd_grp(ts):
tts = np.repeat(ts, 5).astype("float32")
grps = np.tile(np.arange(5), 10).astype("int16")
xspi = gammastd_grp(tts, grps, np.unique(grps).size, 0, 0, 10)
indices = np.array([[0, 10]] * 10, dtype="int16")
xspi = gammastd_grp(tts, grps, np.unique(grps).size, 0, indices)

res = [
-0.38238713,
Expand Down
Loading
Loading