
Merge pull request #46 from WFP-VAM/spi-updates
Grouped operations (and more)
valpesendorfer authored Jan 22, 2024
2 parents 72bc76c + 14a7392 commit 15c0297
Showing 6 changed files with 462 additions and 56 deletions.
2 changes: 1 addition & 1 deletion hdc/algo/_version.py
@@ -1,2 +1,2 @@
"""Version only in this file."""
__version__ = "0.2.1"
__version__ = "0.3.0"
191 changes: 173 additions & 18 deletions hdc/algo/accessors.py
@@ -1,5 +1,5 @@
"""Xarray Accesor classes."""
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Union
from warnings import warn

from dask import is_dask_collection
@@ -100,10 +100,26 @@ def midx(self):
def yidx(self):
return self._tseries.apply(lambda x: self._period_cls(x).yidx).to_xarray()

@property
def ndays(self):
return self._tseries.apply(lambda x: self._period_cls(x).ndays).to_xarray()

@property
def label(self):
return self._tseries.apply(lambda x: str(self._period_cls(x))).to_xarray()

@property
def start_date(self):
return self._tseries.apply(lambda x: self._period_cls(x).start_date).to_xarray()

@property
def end_date(self):
return self._tseries.apply(lambda x: self._period_cls(x).end_date).to_xarray()

@property
def raw(self):
return self._tseries.apply(lambda x: self._period_cls(x).raw).to_xarray()
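
The new period properties above map each timestamp onto attributes of its containing period. A minimal usage sketch, assuming a daily-indexed DataArray and that the properties are surfaced by the `dekad` accessor registered just below; the sample data and the exact label format are illustrative, not taken from this diff:

```python
import pandas as pd
import xarray as xr

import hdc.algo  # noqa: F401 -- registers the accessors

da = xr.DataArray(
    range(30),
    dims=["time"],
    coords={"time": pd.date_range("2023-01-01", periods=30, freq="D")},
)

da.dekad.ndays       # number of days in each timestep's dekad
da.dekad.label       # string label per dekad, e.g. "202301d1"
da.dekad.start_date  # first day of the containing dekad
da.dekad.end_date    # last day of the containing dekad
da.dekad.raw         # raw integer index of the dekad
```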


@xarray.register_dataset_accessor("dekad")
@xarray.register_dataarray_accessor("dekad")
@@ -444,12 +460,35 @@ def spi(
self,
calibration_start: Optional[str] = None,
calibration_stop: Optional[str] = None,
nodata: Optional[Union[float, int]] = None,
groups: Optional[Iterable[int]] = None,
dtype="int16",
):
"""Calculate the SPI along the time dimension."""
"""Calculate the SPI along the time dimension.
Calculates the Standardized Precipitation Index along the time dimension.
Optionally, a calibration start and / or stop date can be provided which
determine the part of the timeseries used to fit the gamma distribution.
`groups` can be supplied as list of group labels (attention, they are required
to be in format {0..n-1} where n is the number of unique groups.
If `groups` is supplied, the SPI will be computed for each individual group.
This is intended to be used when SPI should be calculated for specific timesteps.
"""
if not self._check_for_timedim():
raise MissingTimeError("SPI requires a time dimension!")

from .ops.stats import gammastd_yxt # pylint: disable=import-outside-toplevel
if nodata is None:
if (nodata := self._obj.attrs.get("nodata")) is None:
raise ValueError(
"Need nodata attribute defined, or nodata argument provided."
)

# pylint: disable=import-outside-toplevel
from .ops.stats import (
gammastd_yxt,
gammastd_grp,
)

tix = self._obj.get_index("time")

@@ -479,17 +518,44 @@ def spi(
"Timeseries too short for calculating SPI. Please adjust calibration period!"
)

res = xarray.apply_ufunc(
gammastd_yxt,
self._obj,
kwargs={
"cal_start": calstart_ix,
"cal_stop": calstop_ix,
},
input_core_dims=[["time"]],
output_core_dims=[["time"]],
dask="parallelized",
)
if groups is None:
res = xarray.apply_ufunc(
gammastd_yxt,
self._obj,
kwargs={
"nodata": nodata,
"cal_start": calstart_ix,
"cal_stop": calstop_ix,
},
input_core_dims=[["time"]],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
dask_gufunc_kwargs={"meta": self._obj.data.astype(dtype)},
)

else:
groups = np.array(groups) if not isinstance(groups, np.ndarray) else groups
num_groups = np.unique(groups).size

if not groups.dtype.name == "int16":
warn("Casting groups to int16!")
groups = groups.astype("int16")

res = xarray.apply_ufunc(
gammastd_grp,
self._obj,
groups,
num_groups,
nodata,
calstart_ix,
calstop_ix,
input_core_dims=[["time"], ["grps"], [], [], [], []],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
dask_gufunc_kwargs={"meta": self._obj.data.astype(dtype)},
)
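
A hedged usage sketch for the new `groups` argument. Group labels must cover {0..n-1} and are cast to int16 if needed; the `hdc.algo` accessor path, the nodata value, and the dekad-of-year grouping are assumptions for illustration, not taken from this diff:

```python
import numpy as np

# 10 years of dekadal data: label every timestep with its
# dekad-of-year (0..35), so the gamma distribution is fitted
# per seasonal position rather than over the whole series.
n_years = 10
groups = np.tile(np.arange(36, dtype="int16"), n_years)

spi = da.hdc.algo.spi(
    calibration_start="2010-01-01",
    calibration_stop="2019-12-31",
    groups=groups,
    nodata=-9999,  # falls back to da.attrs["nodata"] when omitted
)
```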

res.attrs.update(
{
@@ -523,6 +589,7 @@ def lroo(self):
input_core_dims=[["time"]],
dask="parallelized",
output_dtypes=["uint8"],
keep_attrs=True,
)

def autocorr(self):
@@ -533,8 +600,7 @@ def autocorr(self):
xarray.DataArray with lag1 autocorrelation
"""
xx = self._obj
nodata = xx.attrs.get("nodata", None)
if nodata is None:
if (nodata := xx.attrs.get("nodata", None)) is None:
warn("Calculating autocorr without nodata value defined!")
if xx.dims[0] == "time":
# I don't know how to tell xarray's map_blocks about
@@ -606,6 +672,90 @@ def mktrend(self):
x.trend.attrs["nodata"] = -2
return x

def mean_grp(
self,
groups: Iterable[int],
nodata: Optional[Union[int, float]] = None,
):
"""Calculate mean over groups along time dimension.
This calculates a simple average over groups along the time
dimension. The groups are identified by an int16 array, and
the **need to be in ascending order from 0 to n-1**!
The function will return an array of original size with averages
at the respective positions.
"""
if not self._check_for_timedim():
raise MissingTimeError("Grouped mean requires a time dimension!")

# pylint: disable=import-outside-toplevel
from .ops.stats import mean_grp

if nodata is None:
nodata = self._obj.attrs.get("nodata", None)

if nodata is None:
raise ValueError("Need to define nodata value!")

groups = (
np.array(groups, dtype="int16")
if not isinstance(groups, np.ndarray)
else groups
)

if groups.size != self._obj.time.size:
raise ValueError("Need array of groups same length as time dimension!")

num_groups = np.unique(groups).size

return xarray.apply_ufunc(
mean_grp,
self._obj,
groups,
num_groups,
nodata,
input_core_dims=[["time"], ["grps"], [], []],
output_core_dims=[["time"]],
keep_attrs=True,
dask="parallelized",
dask_gufunc_kwargs={"meta": self._obj.data},
)
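
`mean_grp` pairs naturally with the grouped SPI: it can compute a per-group climatology in one pass. A sketch under the same assumptions as the SPI example above (dekadal DataArray `da`, `hdc.algo` accessor path, nodata value):

```python
import numpy as np

n_years = 10
groups = np.tile(np.arange(36, dtype="int16"), n_years)  # labels 0..35

clim = da.hdc.algo.mean_grp(groups, nodata=-9999)

# clim keeps the original shape, each timestep holding its group's
# average, so anomalies reduce to a plain subtraction:
anomalies = da - clim
```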


class RollingWindowAlgos(AccessorBase):
"""Class to calculate rolling window algos on dimenson."""

def sum(
self,
window_size: int,
dtype: str = "float32",
dimension: str = "time",
nodata: Optional[Union[int, float]] = None,
):
# pylint: disable=import-outside-toplevel
from .ops.stats import rolling_sum

if nodata is None:
if (nodata := self._obj.attrs.get("nodata")) is None:
raise ValueError(
"Need nodata attribute defined, or nodata argument provided."
)

xx = xarray.apply_ufunc(
rolling_sum,
self._obj,
window_size,
nodata,
input_core_dims=[[dimension], [], []],
output_core_dims=[[dimension]],
keep_attrs=True,
dask="parallelized",
dask_gufunc_kwargs={"meta": self._obj.astype(dtype).data},
)
xx = xx[..., window_size - 1 :]
return xx
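
A hedged sketch of the rolling-window accessor, assuming it is reached through the composite accessor wired up at the bottom of this file and that `da` is the dekadal DataArray from the examples above:

```python
# 3-step rolling sum along "time"
summed = da.hdc.rolling.sum(window_size=3, nodata=-9999)

# The xx[..., window_size - 1:] slice above drops the first
# window_size - 1 timesteps, which lack a complete window, so the
# result is shorter than the input along the core dimension.
```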


class ZonalStatistics(AccessorBase):
"""Class to claculate zonal statistics."""
@@ -614,7 +764,7 @@ def mean(
self,
zones: xarray.DataArray,
zone_ids: Union[List, np.ndarray],
dtype: str = "float64",
dtype: str = "float32",
dim_name: str = "zones",
name: Optional[str] = None,
) -> xarray.DataArray:
@@ -660,6 +810,9 @@ def mean(
"stat": ["mean", "valid"],
}

# convert str dtype to a numpy scalar type
dtype = np.dtype(dtype).type

if is_dask_collection(xx):
dask_name = name
if isinstance(dask_name, str):
@@ -677,7 +830,7 @@
drop_axis=[1, 2],
new_axis=[1, 2],
chunks=chunks,
dtype=dtype,
out_dtype=dtype,
name=dask_name,
)
else:
@@ -687,6 +840,7 @@
num_zones,
xx.nodata,
zones.nodata,
out_dtype=dtype,
)
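
A hedged usage sketch for the zonal mean, reflecting the new float32 default; the `zones` raster and the zone ids are illustrative assumptions:

```python
# zones: integer DataArray over da's spatial dims, carrying its own
# nodata attribute; zone_ids lists the zone labels to aggregate.
zonal = da.hdc.zonal.mean(
    zones=zones,
    zone_ids=[1, 2, 3],
    dtype="float32",  # str dtypes are converted via np.dtype(dtype).type
)
# The result has a "stat" coordinate ["mean", "valid"]: the zonal
# average and, presumably, the count of valid (non-nodata) pixels.
```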

return xarray.DataArray(
@@ -703,5 +857,6 @@ def __init__(self, xarray_obj):
self.algo = PixelAlgorithms(xarray_obj)
self.anom = Anomalies(xarray_obj)
self.iteragg = IterativeAggregation(xarray_obj)
self.rolling = RollingWindowAlgos(xarray_obj)
self.whit = WhittakerSmoother(xarray_obj)
self.zonal = ZonalStatistics(xarray_obj)
