Refactor base indicator classes (#1446)
### Pull Request Checklist:
- [x] This PR addresses an already opened issue (for bug fixes /
features)
    - This PR fixes #1263
- [ ] Tests for the changes have been added (for bug fixes / features)
- [x] (If applicable) Documentation has been added / updated (for bug
fixes / features)
- [x] CHANGES.rst has been updated (with summary of main changes)
- [x] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added

### What kind of change does this PR introduce?

Reorganize some of the base indicator class code by splitting the missing
values handling from the resampling handling. This way, indicators like
`return_level` can reduce the full time axis and still perform some
missing values handling.
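
As a rough illustration of the split (a minimal sketch, not the actual xclim code: the `*Sketch` class names and the `describe` method are hypothetical), the missing-value logic can live in a base class that asks its subclasses which frequency to check:

```python
class CheckMissingSketch:
    """Hypothetical stand-in for a base class that owns the missing-value checks."""

    def _get_missing_freq(self, params):
        # Subclasses decide which frequency the missing check operates on.
        raise NotImplementedError("Subclass me.")

    def describe(self, params):
        # Hypothetical helper: report what the missing check would do.
        freq = self._get_missing_freq(params)
        if freq is False:
            return "missing checks skipped"
        if freq is None:
            return "missing checks over the full time axis"
        return f"missing checks per {freq} period"


class ReducingSketch(CheckMissingSketch):
    """Reduces the whole time axis, so there is no resampling frequency."""

    def _get_missing_freq(self, params):
        return None


class ResamplingSketch(CheckMissingSketch):
    """Resamples, so the check uses the requested `freq`."""

    def _get_missing_freq(self, params):
        return params["freq"]


print(ReducingSketch().describe({}))
print(ResamplingSketch().describe({"freq": "YS"}))
```

With this layout, a fully reducing indicator and a resampling indicator share the same masking code and differ only in the value returned by the hook.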

### Does this PR introduce a breaking change?
No, not yet.


### Other information:
@huard @RondeauG, if I am not mistaken, `return_level` and `fit`
currently have missing values handling disabled because the previous
classes could not combine it with a full reduction of "time". Now that
it is possible, should we activate it?

Similarly, `stats` had the "Any" missing method forced. Is there a
reason for that? Here, I removed the argument, meaning it will use
"from_context".
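
To make the "from_context" remark in the description concrete, here is a minimal sketch of how such a setting could be resolved. The `OPTIONS` dictionary, its `"check_missing"` key, its `"any"` value, and the `resolve_missing` helper are all assumptions made for illustration and do not reproduce xclim's internals:

```python
# Hypothetical global configuration, standing in for the library-wide options.
OPTIONS = {"check_missing": "any"}


def resolve_missing(method: str) -> str:
    """Return the effective missing-value method for an indicator."""
    if method == "from_context":
        # Defer to whatever the global configuration currently says.
        return OPTIONS["check_missing"]
    # An explicitly set method always wins over the global configuration.
    return method


print(resolve_missing("from_context"))
print(resolve_missing("skip"))
```

Under this reading, dropping a forced method from an indicator means users can change its behaviour through the global options instead of being locked into one method.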
aulemahal authored Aug 31, 2023
2 parents 36dcbdf + 71c52a7 commit 2490fcc
Showing 7 changed files with 211 additions and 57 deletions.
8 changes: 6 additions & 2 deletions CHANGES.rst
@@ -22,12 +22,15 @@ New features and enhancements
* Added new function ``xclim.sdba.properties.std`` to calculate the standard deviation of a variable over all years at a given time resolution. (:pull:`1445`).
* Amended the documentation of ``xclim.sdba.properties.trend`` to document already existing functionality of calculating the return values of ``scipy.stats.linregress``. (:pull:`1445`).
* Add support for setting optional variables through the `ds` argument. (:issue:`1432`, :pull:`1435`).
* New ``xclim.core.calendar.is_offset_divisor`` to test if a given freq divides another one evenly (:pull:`1446`).
* Missing value objects now support input timeseries of quarterly and yearly frequencies (:pull:`1446`).
* Missing value checks enabled for all "generic" indicators (``return_level``, ``fit`` and ``stats``) (:pull:`1446`).

Bug fixes
^^^^^^^^^
* Fix `kldiv` docstring so the math formula renders to HTML. (:issue:`1408`, :pull:`1409`).
* Fix ``kldiv`` docstring so the math formula renders to HTML. (:issue:`1408`, :pull:`1409`).
* Fix the registry entries of "generic" indicators. (:issue:`1423`, :pull:`1424`).
* Fix `jetstream_metric_woollings` so it uses the `vertical` coordinate identified by `cf-xarray`, instead of `pressure`. (:issue:`1421`, :pull:`1422`). Add logic to handle coordinates in decreasing order, or for longitudes defined from 0-360 instead of -180 to 180. (:issue:`1429`, :pull:`1430`).
* Fix ``jetstream_metric_woollings`` so it uses the `vertical` coordinate identified by `cf-xarray`, instead of `pressure`. (:issue:`1421`, :pull:`1422`). Add logic to handle coordinates in decreasing order, or for longitudes defined from 0-360 instead of -180 to 180. (:issue:`1429`, :pull:`1430`).
* Fix virtual indicator attribute assignment causing individual indicator's realm to be ignored. (:issue:`1425`, :pull:`1426`).
* Fixes the `raise_flags` argument of ``xclim.core.dataflags.data_flags`` so that an Exception is only raised when some checkups fail (:issue:`1456`, :pull:`1457`).
* Fix ``xclim.indices.generic.get_zones`` so that `bins` can be given as input without error. (:pull:`1455`).
@@ -49,6 +52,7 @@ Internal changes
* Added a helper module ``_finder`` in the notebooks folder so that the working directory can always be found, with redundancies in place to prevent scripts from failing if the helper file is not found. (:pull:`1449`).
* Added a manual cache-cleaning workflow (based on `GitHub cache-cleaning example <https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#managing-caches>`_), triggered when a branch has been merged. (:pull:`1462`).
* Added a workflow for posting updates to the xclim Mastodon account (using `cbrgm/mastodon-github-action <https://github.com/cbrgm/mastodon-github-action>`_, triggered when a new version is published. (:pull:`1462`).
* Refactor base indicator classes and fix misleading inheritance of ``return_level`` (:issue:`1263`, :pull:`1446`).

Breaking changes
^^^^^^^^^^^^^^^^
14 changes: 11 additions & 3 deletions tests/test_generic_indicators.py
@@ -11,14 +11,18 @@ def test_simple(self, pr_ndseries, random):
ts = generic.stats(pr, freq="YS", op="max")
p = generic.fit(ts, dist="gumbel_r")
assert p.attrs["estimator"] == "Maximum likelihood"
assert "time" not in p.dims

def test_nan(self, pr_series, random):
r = random.random(22)
r[0] = np.nan
pr = pr_series(r)

out = generic.fit(pr, dist="norm")
assert not np.isnan(out.values[0])
assert np.isnan(out.values[0])
with set_options(check_missing="skip"):
out = generic.fit(pr, dist="norm")
assert not np.isnan(out.values[0])

def test_ndim(self, pr_ndseries, random):
pr = pr_ndseries(random.random((100, 1, 2)))
@@ -28,6 +32,9 @@ def test_ndim(self, pr_ndseries, random):

def test_options(self, q_series, random):
q = q_series(random.random(19))
out = generic.fit(q, dist="norm")
np.testing.assert_array_equal(out.isnull(), False)

with set_options(missing_options={"at_least_n": {"n": 10}}):
out = generic.fit(q, dist="norm")
np.testing.assert_array_equal(out.isnull(), False)
@@ -87,8 +94,9 @@ def test_ndq(self, ndq_series):
assert out.attrs["units"] == "m3 s-1"

def test_missing(self, ndq_series):
a = ndq_series
a = ndq_series.where(~((a.time.dt.dayofyear == 5) * (a.time.dt.year == 1902)))
a = ndq_series.where(
~((ndq_series.time.dt.dayofyear == 5) & (ndq_series.time.dt.year == 1902))
)
assert a.shape == (5000, 2, 3)
out = generic.stats(a, op="max", month=1)

22 changes: 22 additions & 0 deletions tests/test_missing.py
@@ -31,6 +31,21 @@ def test_monthly_input(self, random):
mb = missing.MissingBase(ts, freq="AS", src_timestep="M", season="JJA")
assert mb.count == 3

def test_seasonal_input(self, random):
"""Creating array with 11 seasons."""
n = 11
time = xr.cftime_range(start="2002-04-01", periods=n, freq="QS-JAN")
ts = xr.DataArray(random.random(n), dims="time", coords={"time": time})
mb = missing.MissingBase(ts, freq="YS", src_timestep="QS-JAN")
# Make sure count matches the number of seasons expected in each YS period.
np.testing.assert_array_equal(mb.count, [4, 4, 4, 1])

with pytest.raises(
NotImplementedError,
match="frequency that is not aligned with the source timestep.",
):
missing.MissingBase(ts, freq="YS", src_timestep="QS-DEC")


class TestMissingAnyFills:
def test_missing_days(self, tas_series):
@@ -144,6 +159,13 @@ def test_hourly(self, pr_hr_series):
out = missing.missing_any(pr, freq="MS")
np.testing.assert_array_equal(out, [True, False, True])

def test_seasonal(self, random):
n = 11
time = xr.cftime_range(start="2002-01-01", periods=n, freq="QS-JAN")
ts = xr.DataArray(random.random(n), dims="time", coords={"time": time})
out = missing.missing_any(ts, freq="YS")
np.testing.assert_array_equal(out, [False, False, True])


class TestMissingWMO:
def test_missing_days(self, tas_series):
53 changes: 53 additions & 0 deletions xclim/core/calendar.py
@@ -39,6 +39,7 @@
"ensure_cftime_array",
"get_calendar",
"interp_calendar",
"is_offset_divisor",
"max_doy",
"parse_offset",
"percentile_doy",
@@ -841,6 +842,58 @@ def construct_offset(mult: int, base: str, start_anchored: bool, anchor: str | N
)


def is_offset_divisor(divisor: str, offset: str):
    """Check that divisor is a divisor of offset.

    A frequency is a "divisor" of another if a whole number of periods of the
    former fit within a single period of the latter.

    Parameters
    ----------
    divisor: str
        The divisor frequency.
    offset: str
        The large frequency.

    Returns
    -------
    bool

    Examples
    --------
    >>> is_offset_divisor("QS-Jan", "YS")
    True
    >>> is_offset_divisor("QS-DEC", "AS-JUL")
    False
    >>> is_offset_divisor("D", "M")
    True
    """
    if compare_offsets(divisor, ">", offset):
        return False
    # Reconstruct offsets anchored at the start of the period
    # to have comparable quantities, also get "offset" objects
    mA, bA, sA, aA = parse_offset(divisor)
    offAs = pd.tseries.frequencies.to_offset(construct_offset(mA, bA, True, aA))

    mB, bB, sB, aB = parse_offset(offset)
    offBs = pd.tseries.frequencies.to_offset(construct_offset(mB, bB, True, aB))
    tB = pd.date_range("1970-01-01T00:00:00", freq=offBs, periods=13)

    if bA in "WDHTLUN" or bB in "WDHTLUN":
        # Simple length comparison is sufficient for submonthly freqs
        # In case one of bA or bB is > W, we test many to be sure.
        tA = pd.date_range("1970-01-01T00:00:00", freq=offAs, periods=13)
        return np.all(
            (np.diff(tB)[:, np.newaxis] / np.diff(tA)[np.newaxis, :]) % 1 == 0
        )

    # else, we test alignment with some real dates
    # If both fall on offAs, then it means divisor is aligned with offset at those dates
    # if N=13 is True, then it is always True
    # As divisor <= offset, this means divisor is a "divisor" of offset.
    return all(offAs.is_on_offset(d) for d in tB)
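
For month-based frequencies, the alignment rule can be illustrated without pandas. This is a simplified sketch, assuming each frequency is described by a hypothetical `(length in months, anchor month)` pair; `is_month_divisor` is an illustrative name and this is not how `is_offset_divisor` is implemented:

```python
def is_month_divisor(divisor: tuple[int, int], offset: tuple[int, int]) -> bool:
    """Return True if `divisor` periods tile every `offset` period exactly.

    Each frequency is a hypothetical (length in months, anchor month) pair,
    e.g. (3, 1) for "QS-JAN" or (12, 7) for "AS-JUL".
    """
    div_len, div_anchor = divisor
    off_len, off_anchor = offset
    if div_len > off_len:
        # A longer period can never divide a shorter one.
        return False
    # The longer period must contain a whole number of shorter periods...
    whole_number = off_len % div_len == 0
    # ...and its start must coincide with the start of a shorter period.
    aligned = (off_anchor - div_anchor) % div_len == 0
    return whole_number and aligned


print(is_month_divisor((3, 1), (12, 1)))   # "QS-JAN" vs "YS": True
print(is_month_divisor((3, 12), (12, 7)))  # "QS-DEC" vs "AS-JUL": False
```

Both results agree with the first two docstring examples above: quarters anchored in January tile a calendar year, while quarters anchored in December straddle the start of a July-anchored year.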


def _interpolate_doy_calendar(
source: xr.DataArray, doy_max: int, doy_min: int = 1
) -> xr.DataArray:
140 changes: 98 additions & 42 deletions xclim/core/indicator.py
@@ -1353,11 +1353,13 @@ def _show_deprecation_warning(self):
)


class ResamplingIndicator(Indicator):
"""Indicator that performs a resampling computation.
class CheckMissingIndicator(Indicator):
"""Class adding missing value checks to indicators.
Compared to the base Indicator, this adds the handling of missing data,
and the check of allowed periods.
This should not be used as-is, but subclassed by implementing the `_get_missing_freq` method.
This method will be called in `_postprocess` using the compute parameters as only argument.
It should return a freq string, the same as the output freq of the computed data.
It can also be "None" to indicate that the full time axis has been reduced, or "False" to skip the missing checks.
Parameters
----------
@@ -1366,24 +1368,10 @@ class ResamplingIndicator(Indicator):
None, this will be determined by the global configuration (see `xclim.set_options`). Defaults to "from_context".
missing_options : dict, optional
Arguments to pass to the `missing` function. If None, this will be determined by the global configuration.
allowed_periods : Sequence[str], optional
A list of allowed periods, i.e. base parts of the `freq` parameter. For example, indicators meant to be
computed annually only will have `allowed_periods=["A"]`. `None` means "any period" or that the
indicator doesn't take a `freq` argument.
"""

missing = "from_context"
missing_options: dict | None = None
allowed_periods: list[str] | None = None

@classmethod
def _ensure_correct_parameters(cls, parameters):
if "freq" not in parameters:
raise ValueError(
"ResamplingIndicator require a 'freq' argument, use the base Indicator"
" class if your computation doesn't perform any resampling."
)
return super()._ensure_correct_parameters(parameters)

def __init__(self, **kwds):
if self.missing == "from_context" and self.missing_options is not None:
@@ -1399,23 +1387,6 @@ def __init__(self, **kwds):

super().__init__(**kwds)

def _preprocess_and_checks(self, das, params):
"""Perform parent's checks and also check if freq is allowed."""
das, params = super()._preprocess_and_checks(das, params)

# Check if the period is allowed:
if (
self.allowed_periods is not None
and parse_offset(params["freq"])[1] not in self.allowed_periods
):
raise ValueError(
f"Resampling frequency {params['freq']} is not allowed for indicator "
f"{self.identifier} (needs something equivalent to one "
f"of {self.allowed_periods})."
)

return das, params

def _history_string(self, **kwargs):
if self.missing == "from_context":
missing = OPTIONS[CHECK_MISSING]
@@ -1430,11 +1401,16 @@ def _history_string(self, **kwargs):

return super()._history_string(**kwargs) + opt_str

def _get_missing_freq(self, params):
"""Return the resampling frequency to be used in the missing values check."""
raise NotImplementedError("Don't use `CheckMissingIndicator` directly.")

def _postprocess(self, outs, das, params):
"""Masking of missing values."""
outs = super()._postprocess(outs, das, params)

if self.missing != "skip":
freq = self._get_missing_freq(params)
if self.missing != "skip" and freq is not False:
# Mask results that do not meet criteria defined by the `missing` method.
# This means all outputs must have the same dimensions as the broadcasted inputs (excluding time)
options = self.missing_options or OPTIONS[MISSING_OPTIONS].get(
@@ -1444,24 +1420,96 @@ def _postprocess(self, outs, das, params):
# We flag periods according to the missing method. skip variables without a time coordinate.
src_freq = self.src_freq if isinstance(self.src_freq, str) else None
miss = (
self._missing(
da, params["freq"], src_freq, options, params.get("indexer", {})
)
self._missing(da, freq, src_freq, options, params.get("indexer", {}))
for da in das.values()
if "time" in da.coords
)
# Reduce by or and broadcast to ensure the same length in time
# When indexing is used and there are no valid points in the last period, mask will not include it
mask = reduce(np.logical_or, miss)
if isinstance(mask, DataArray) and mask.time.size < outs[0].time.size:
if (
isinstance(mask, DataArray)
and "time" in mask.dims
and mask.time.size < outs[0].time.size
):
mask = mask.reindex(time=outs[0].time, fill_value=True)
outs = [out.where(~mask) for out in outs]

return outs


class ResamplingIndicatorWithIndexing(ResamplingIndicator):
"""Resampling indicator that also injects "indexer" kwargs to subset the inputs before computation."""
class ReducingIndicator(CheckMissingIndicator):
    """Indicator that performs a time-reducing computation.

    Compared to the base Indicator, this adds the handling of missing data.

    Parameters
    ----------
    missing : {any, wmo, pct, at_least_n, skip, from_context}
        The name of the missing value method. See `xclim.core.missing.MissingBase` to create new custom methods. If
        None, this will be determined by the global configuration (see `xclim.set_options`). Defaults to "from_context".
    missing_options : dict, optional
        Arguments to pass to the `missing` function. If None, this will be determined by the global configuration.
    """

    def _get_missing_freq(self, params):
        """Return None, to indicate that the full time axis is to be reduced."""
        return None


class ResamplingIndicator(CheckMissingIndicator):
    """Indicator that performs a resampling computation.

    Compared to the base Indicator, this adds the handling of missing data,
    and the check of allowed periods.

    Parameters
    ----------
    missing : {any, wmo, pct, at_least_n, skip, from_context}
        The name of the missing value method. See `xclim.core.missing.MissingBase` to create new custom methods. If
        None, this will be determined by the global configuration (see `xclim.set_options`). Defaults to "from_context".
    missing_options : dict, optional
        Arguments to pass to the `missing` function. If None, this will be determined by the global configuration.
    allowed_periods : Sequence[str], optional
        A list of allowed periods, i.e. base parts of the `freq` parameter. For example, indicators meant to be
        computed annually only will have `allowed_periods=["A"]`. `None` means "any period" or that the
        indicator doesn't take a `freq` argument.
    """

    allowed_periods: list[str] | None = None

    @classmethod
    def _ensure_correct_parameters(cls, parameters):
        if "freq" not in parameters:
            raise ValueError(
                "ResamplingIndicator require a 'freq' argument, use the base Indicator"
                " class if your computation doesn't perform any resampling."
            )
        return super()._ensure_correct_parameters(parameters)

    def _get_missing_freq(self, params):
        return params["freq"]

    def _preprocess_and_checks(self, das, params):
        """Perform parent's checks and also check if freq is allowed."""
        das, params = super()._preprocess_and_checks(das, params)

        # Check if the period is allowed:
        if (
            self.allowed_periods is not None
            and parse_offset(params["freq"])[1] not in self.allowed_periods
        ):
            raise ValueError(
                f"Resampling frequency {params['freq']} is not allowed for indicator "
                f"{self.identifier} (needs something equivalent to one "
                f"of {self.allowed_periods})."
            )

        return das, params


class IndexingIndicator(Indicator):
"""Indicator that also injects "indexer" kwargs to subset the inputs before computation."""

@classmethod
def _injected_parameters(cls):
@@ -1490,6 +1538,12 @@ def _preprocess_and_checks(self, das: dict[str, DataArray], params: dict[str, An
return das, params


class ResamplingIndicatorWithIndexing(ResamplingIndicator, IndexingIndicator):
    """Resampling indicator that also injects "indexer" kwargs to subset the inputs before computation."""

    pass


class Daily(ResamplingIndicator):
"""Class for daily inputs and resampling computes."""

@@ -1503,6 +1557,8 @@ class Hourly(ResamplingIndicator):


base_registry["Indicator"] = Indicator
base_registry["ReducingIndicator"] = ReducingIndicator
base_registry["IndexingIndicator"] = IndexingIndicator
base_registry["ResamplingIndicator"] = ResamplingIndicator
base_registry["ResamplingIndicatorWithIndexing"] = ResamplingIndicatorWithIndexing
base_registry["Hourly"] = Hourly