-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix & improve grouped spi #50
Changes from 8 commits
d9097ee
b194183
bb74625
6d62a0d
f80bf11
367cfeb
9e81abd
7bbd550
c75efc8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
"""Version only in this file.""" | ||
|
||
__version__ = "0.4.0" | ||
__version__ = "0.5.0" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,11 +7,11 @@ | |
import dask.array as da | ||
from dask.base import tokenize | ||
import numpy as np | ||
import pandas as pd | ||
import xarray | ||
|
||
from . import ops | ||
from .dekad import Dekad | ||
from .utils import get_calibration_indices, to_linspace | ||
|
||
__all__ = [ | ||
"Anomalies", | ||
|
@@ -466,7 +466,7 @@ def spi( | |
calibration_start: Optional[str] = None, | ||
calibration_stop: Optional[str] = None, | ||
nodata: Optional[Union[float, int]] = None, | ||
groups: Optional[Iterable[int]] = None, | ||
groups: Optional[Iterable[Union[int, float, str]]] = None, | ||
dtype="int16", | ||
): | ||
"""Calculate the SPI along the time dimension. | ||
|
@@ -475,8 +475,7 @@ def spi( | |
Optionally, a calibration start and / or stop date can be provided which | ||
determine the part of the timeseries used to fit the gamma distribution. | ||
|
||
`groups` can be supplied as list of group labels (attention, they are required | ||
to be in format {0..n-1} where n is the number of unique groups. | ||
`groups` can be supplied as list of group labels. | ||
If `groups` is supplied, the SPI will be computed for each individual group. | ||
This is intended to be used when SPI should be calculated for specific timesteps. | ||
""" | ||
|
@@ -497,33 +496,31 @@ def spi( | |
|
||
tix = self._obj.get_index("time") | ||
|
||
calstart_ix = 0 | ||
if calibration_start is not None: | ||
calstart = pd.Timestamp(calibration_start) | ||
if calstart > tix[-1]: | ||
raise ValueError( | ||
"Calibration start cannot be greater than last timestamp!" | ||
) | ||
(calstart_ix,) = tix.get_indexer([calstart], method="bfill") | ||
if calibration_start is None: | ||
calibration_start = tix[0] | ||
|
||
calstop_ix = tix.size | ||
if calibration_stop is not None: | ||
calstop = pd.Timestamp(calibration_stop) | ||
if calstop < tix[0]: | ||
raise ValueError( | ||
"Calibration stop cannot be smaller than first timestamp!" | ||
) | ||
(calstop_ix,) = tix.get_indexer([calstop], method="ffill") + 1 | ||
if calibration_stop is None: | ||
calibration_stop = tix[-1] | ||
|
||
if calstart_ix >= calstop_ix: | ||
raise ValueError("calibration_start < calibration_stop!") | ||
if calibration_start > tix[-1:]: | ||
raise ValueError("Calibration start cannot be greater than last timestamp!") | ||
|
||
if abs(calstop_ix - calstart_ix) <= 1: | ||
raise ValueError( | ||
"Timeseries too short for calculating SPI. Please adjust calibration period!" | ||
) | ||
if calibration_stop < tix[:1]: | ||
raise ValueError("Calibration stop cannot be smaller than first timestamp!") | ||
|
||
if groups is None: | ||
calstart_ix, calstop_ix = get_calibration_indices( | ||
tix, calibration_start, calibration_stop | ||
) | ||
|
||
if calstart_ix >= calstop_ix: | ||
raise ValueError("calibration_start < calibration_stop!") | ||
|
||
if abs(calstop_ix - calstart_ix) <= 1: | ||
raise ValueError( | ||
"Timeseries too short for calculating SPI. Please adjust calibration period!" | ||
) | ||
|
||
res = xarray.apply_ufunc( | ||
gammastd_yxt, | ||
self._obj, | ||
|
@@ -540,22 +537,37 @@ def spi( | |
) | ||
|
||
else: | ||
groups = np.array(groups) if not isinstance(groups, np.ndarray) else groups | ||
num_groups = np.unique(groups).size | ||
|
||
if not groups.dtype.name == "int16": | ||
warn("Casting groups to int16!") | ||
groups = groups.astype("int16") | ||
groups, keys = to_linspace(np.array(groups, dtype="str")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @Kirill888 the idea is that the user passes in a list / array / tuple / ... containing group labels, which puts the [...] To take this burden off the user, this allows any labels and converts them to a [...] For a practical example, we'll be using this to calculate SPIs for the full timeseries, grouping the time dimension into dekads per year. The user can then simply pass [...] |
||
|
||
if len(groups) != len(self._obj.time): | ||
raise ValueError("Need array of groups same length as time dimension!") | ||
|
||
groups = groups.astype("int16") | ||
num_groups = len(keys) | ||
|
||
cal_indices = get_calibration_indices( | ||
tix, calibration_start, calibration_stop, groups, num_groups | ||
) | ||
valpesendorfer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# assert for mypy | ||
assert isinstance(cal_indices, np.ndarray) | ||
|
||
if np.any(cal_indices[:, 0] >= cal_indices[:, 1]): | ||
raise ValueError("calibration_start < calibration_stop!") | ||
|
||
if np.any(np.diff(cal_indices, axis=1) <= 1): | ||
raise ValueError( | ||
"Timeseries too short for calculating SPI. Please adjust calibration period!" | ||
) | ||
|
||
res = xarray.apply_ufunc( | ||
gammastd_grp, | ||
self._obj, | ||
groups, | ||
num_groups, | ||
nodata, | ||
calstart_ix, | ||
calstop_ix, | ||
input_core_dims=[["time"], ["grps"], [], [], [], []], | ||
cal_indices, | ||
input_core_dims=[["time"], ["grps"], [], [], ["start", "stop"]], | ||
output_core_dims=[["time"]], | ||
keep_attrs=True, | ||
dask="parallelized", | ||
|
@@ -564,8 +576,8 @@ def spi( | |
|
||
res.attrs.update( | ||
{ | ||
"spi_calibration_start": str(tix[calstart_ix].date()), | ||
"spi_calibration_stop": str(tix[calstop_ix - 1].date()), | ||
"spi_calibration_start": str(tix[tix >= calibration_start][0]), | ||
"spi_calibration_stop": str(tix[tix <= calibration_stop][-1]), | ||
} | ||
) | ||
|
||
|
@@ -806,7 +818,7 @@ def mean( | |
|
||
# set null values to nodata value | ||
xx = xx.where(xx.notnull(), xx.nodata) | ||
|
||
attrs = xx.attrs | ||
num_zones = len(zone_ids) | ||
dims = (xx.dims[0], dim_name, "stat") | ||
coords = { | ||
|
@@ -849,7 +861,7 @@ def mean( | |
) | ||
|
||
return xarray.DataArray( | ||
data=data, dims=dims, coords=coords, attrs={}, name=name | ||
data=data, dims=dims, coords=coords, attrs=attrs, name=name | ||
) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this comparing `str > [str]`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comparing `str` to `pd.Timestamp`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why `tix[-1:]` and not `tix[-1]`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs to remain a `pd.DatetimeIndex` with one element to work ... maybe a bit hacky & counterintuitive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could explicitly convert the string to a date object (`pd.Timestamp` or whatever) and then compare the elements if that's clearer