diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 281d721e1..b28b0e176 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -4,7 +4,16 @@ Changelog
 v0.53.0 (unreleased)
 --------------------
-Contributors to this version: Adrien Lamarche (:user:`LamAdr`), Trevor James Smith (:user:`Zeitsperre`), Éric Dupuis (:user:`coxipi`).
+Contributors to this version: Adrien Lamarche (:user:`LamAdr`), Trevor James Smith (:user:`Zeitsperre`), Éric Dupuis (:user:`coxipi`), Pascal Bourgault (:user:`aulemahal`).
+
+New indicators
+^^^^^^^^^^^^^^
+* New ``heat_spell_frequency``, ``heat_spell_max_length`` and ``heat_spell_total_length``: spell-length statistics on a bivariate condition that uses the average over a window by default. (:pull:`1885`).
+
+New features and enhancements
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+* New generic function ``xclim.indices.generic.spell_mask`` that returns a mask of the days that are part of a spell. It supports multivariate conditions and weights, and is used in the new generic index ``xclim.indices.generic.bivariate_spell_length_statistics``, which extends ``spell_length_statistics`` to two variables. (:pull:`1885`).
+* Indicator parameters can now be assigned a new name, different from the argument name in the compute function. (:pull:`1885`).
 
 Bug fixes
 ^^^^^^^^^
@@ -32,6 +41,7 @@
 * Many ``DeprecationWarning`` and ``FutureWarning`` messages emitted from `xarray` and `pint` have been addressed. (:issue:`1719`, :pull:`1881`).
 * The codebase has been adjusted to address many `pylint`-related warnings and errors. In some cases, `casting` was used to redefine some `numpy` and `xarray` objects. (:issue:`1719`, :pull:`1881`).
 * ``xclim.core`` now uses absolute imports for clarity and some objects commonly used in the module have been moved to hidden submodules. (:issue:`1719`, :pull:`1881`).
+* ``xclim.core.indicator.Parameter`` has a new attribute ``compute_name``, while ``xclim.core.indicator.Indicator`` no longer has a ``_variable_mapping``. The translation from parameter (and variable) names in the indicator to the names expected by the compute function is now handled by ``Indicator._get_compute_args``. (:pull:`1885`).
 
 v0.52.0 (2024-08-08)
 --------------------
diff --git a/tests/test_formatting.py b/tests/test_formatting.py
index 98ffb2563..7f58c2a67 100644
--- a/tests/test_formatting.py
+++ b/tests/test_formatting.py
@@ -35,7 +35,7 @@ def test_indicator_docstring():
     assert (
         doc[41]
         == " Total number of series of at least {window} consecutive days with daily minimum temperature above "
-        "{thresh_tasmin} and daily maximum temperature above {thresh_tasmax} (heat_wave_events), "
+        "{thresh_tasmin} and daily maximum temperature above {thresh_tasmax}, "
         "with additional attributes: **description**: {freq} number of heat wave events within a given period. "
         "A heat wave occurs when daily minimum and maximum temperatures exceed {thresh_tasmin} and {thresh_tasmax}, "
         "respectively, over at least {window} days."
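As a quick illustration of the ``spell_mask`` entry above, here is a minimal usage sketch. The toy data and the expected masks are lifted directly from ``TestSpellMask.test_single_variable`` in the ``tests/test_generic.py`` diff below; only the assembly into a standalone script is new.

```python
import numpy as np
import xarray as xr

from xclim.indices import generic

# Toy series, as in TestSpellMask: a single "bump" of high values.
data = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",))

# A day is in a spell if it belongs to a 3-day window whose minimum is >= 2:
# only the run [2, 3, 2] qualifies, so days 2-4 are flagged.
mask = generic.spell_mask(data, 3, "min", ">=", 2)
np.testing.assert_array_equal(mask, np.array([0, 0, 1, 1, 1, 0, 0, 0]).astype(bool))

# With a weighted 3-day mean (weights require win_reducer="mean"),
# the flagged period grows by one day, as in the tests.
mask = generic.spell_mask(data, 3, "mean", ">", 2, weights=[0.2, 0.4, 0.4])
np.testing.assert_array_equal(mask, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool))
```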
diff --git a/tests/test_generic.py b/tests/test_generic.py index 8448a5607..0054dd935 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -663,3 +663,108 @@ def test_select_time_errors(self): with pytest.raises(TypeError): select_time(da, doy_bounds=(300, 203, 202)) + + +class TestSpellMask: + def test_single_variable(self): + data = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",)) + + out = generic.spell_mask(data, 3, "min", ">=", 2) + np.testing.assert_array_equal( + out, np.array([0, 0, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask(data, 3, "max", ">=", 2) + np.testing.assert_array_equal( + out, np.array([1, 1, 1, 1, 1, 1, 1, 0]).astype(bool) + ) + + out = generic.spell_mask(data, 2, "mean", ">=", 2) + np.testing.assert_array_equal( + out, np.array([0, 0, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask(data, 3, "mean", ">", 2, weights=[0.2, 0.4, 0.4]) + np.testing.assert_array_equal( + out, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + def test_multiple_variables(self): + data1 = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",)) + data2 = xr.DataArray([1, 2, 3, 2, 1, 0, 0, 0], dims=("time",)) + + out = generic.spell_mask([data1, data2], 3, "min", ">=", [2, 2]) + np.testing.assert_array_equal( + out, np.array([0, 0, 0, 0, 0, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask( + [data1, data2], 3, "min", ">=", [2, 2], var_reducer="any" + ) + np.testing.assert_array_equal( + out, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask([data1, data2], 2, "mean", ">=", [2, 2]) + np.testing.assert_array_equal( + out, np.array([0, 0, 1, 1, 0, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask( + [data1, data2], 3, "mean", ">", [2, 1.5], weights=[0.2, 0.4, 0.4] + ) + np.testing.assert_array_equal( + out, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + def test_errors(self): + data = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",)) + + # Threshold must be seq + with pytest.raises(ValueError, match="must be a sequence of the same length"): + generic.spell_mask([data, data], 3, "min", "<=", 2) + + # Threshold must be same length + with pytest.raises(ValueError, match="must be a sequence of the same length"): + generic.spell_mask([data, data], 3, "min", "<=", [2]) + + # Weights must have win_reducer = 'mean' + with pytest.raises( + ValueError, match="is only supported if 'win_reducer' is 'mean'" + ): + generic.spell_mask(data, 3, "min", "<=", 2, weights=[1, 2, 3]) + + # Weights must have same length as window + with pytest.raises(ValueError, match="Weights have a different length"): + generic.spell_mask(data, 3, "mean", "<=", 2, weights=[1, 2]) + + +def test_spell_length_statistics_multi(tasmin_series, tasmax_series): + tn = tasmin_series( + np.zeros( + 365, + ) + + 270, + start="2001-01-01", + ) + tx = tasmax_series( + np.zeros( + 365, + ) + + 270, + start="2001-01-01", + ) + + outc, outs, outm = generic.bivariate_spell_length_statistics( + tn, + "0 °C", + tx, + "1°C", + window=5, + win_reducer="min", + op="<", + spell_reducer=["count", "sum", "max"], + freq="YS", + ) + xr.testing.assert_equal(outs, outm) + np.testing.assert_allclose(outc, 1) diff --git a/tests/test_indicators.py b/tests/test_indicators.py index fa9ca2cdc..29aa1a988 100644 --- a/tests/test_indicators.py +++ b/tests/test_indicators.py @@ -709,8 +709,7 @@ def test_indicator_from_dict(): assert ind.parameters["threshold"].description == "A threshold temp" # Injection of parameters assert 
ind.injected_parameters["op"] == "<" - # Default value for input variable injected and meta injected - assert ind._variable_mapping["data"] == "tas" + assert ind.parameters["tas"].compute_name == "data" assert signature(ind).parameters["tas"].default == "tas" assert ind.parameters["tas"].units == "[temperature]" diff --git a/tests/test_temperature.py b/tests/test_temperature.py index 23ed23a0d..45f3c94b7 100644 --- a/tests/test_temperature.py +++ b/tests/test_temperature.py @@ -610,6 +610,100 @@ def test_3d_data_with_nans(self, open_dataset): assert np.isnan(gdd.values[0, -1, -1]) +class TestHeatSpellFrequency: + def test_1d(self, tasmax_series, tasmin_series): + tn1 = np.zeros(366) + tx1 = np.zeros(366) + tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23]) + tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31]) + + tn = tasmin_series(tn1 + K2C, start="1/1/2000") + tx = tasmax_series(tx1 + K2C, start="1/1/2000") + + hsf = atmos.heat_spell_frequency( + tn, + tx, + thresh_tasmin="22.1 C", + thresh_tasmax="30.1 C", + freq="YS", + ) + np.testing.assert_allclose(hsf.values[:1], 2) + + hsf = atmos.heat_spell_frequency( + tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", window=5, freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 1) + + # no hs + hsf = atmos.heat_spell_frequency( + tn, tx, thresh_tasmin="40 C", thresh_tasmax="40 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 0) + + +class TestHeatSpellMaxLength: + def test_1d(self, tasmax_series, tasmin_series): + tn1 = np.zeros(366) + tx1 = np.zeros(366) + tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23]) + tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31]) + + tn = tasmin_series(tn1 + K2C, start="1/1/2000") + tx = tasmax_series(tx1 + K2C, start="1/1/2000") + + hsf = atmos.heat_spell_max_length( + tn, + tx, + thresh_tasmin="22.1 C", + thresh_tasmax="30.1 C", + freq="YS", + ) + np.testing.assert_allclose(hsf.values[:1], 4) + + hsf = atmos.heat_spell_max_length( + tn, + tx, + thresh_tasmin="22 C", + thresh_tasmax="30 C", + window=5, + freq="YS", + ) + np.testing.assert_allclose(hsf.values[:1], 5) + + # no hs + hsf = atmos.heat_spell_max_length( + tn, tx, thresh_tasmin="40 C", thresh_tasmax="40 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 0) + + +class TestHeatSpellTotalLength: + def test_1d(self, tasmax_series, tasmin_series): + tn1 = np.zeros(366) + tx1 = np.zeros(366) + tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23]) + tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31]) + + tn = tasmin_series(tn1 + K2C, start="1/1/2000") + tx = tasmax_series(tx1 + K2C, start="1/1/2000") + + hsf = atmos.heat_spell_total_length( + tn, tx, thresh_tasmin="22.1 C", thresh_tasmax="30.1 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 7) + + hsf = atmos.heat_spell_total_length( + tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", window=5, freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 5) + + # no hs + hsf = atmos.heat_spell_total_length( + tn, tx, thresh_tasmin="40 C", thresh_tasmax="40 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 0) + + class TestHeatWaveFrequency: def test_1d(self, tasmax_series, tasmin_series): tn1 = np.zeros(366) diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 5efb09be7..48e296e31 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -78,8 +78,9 @@ units: # Only valid if "compute" points to a generic function default : description: + name : # Change the name 
of the parameter (similar to what `input` does for variables)
     kind: # Override the parameter kind.
-          # This is mostly useful for transforming an optional variable into a required one by passing ``kind: 0``.
+           # This is mostly useful for transforming an optional variable into a required one by passing ``kind: 0``.
     ...
   ...  # and so on.
@@ -181,7 +182,7 @@ class _empty:  # pylint: disable=too-few-public-methods
 class Parameter:
     """Class for storing an indicator's controllable parameter.
 
-    For retrocompatibility, this class implements a "getitem" and a special "contains".
+    For convenience, this class implements a special "contains".
 
     Example
     -------
@@ -198,6 +199,8 @@ class Parameter:
     kind: InputKind
     default: Any = _empty_default
+    # Name of the compute function's argument corresponding to this parameter.
+    compute_name: str = _empty
     description: str = ""
     units: str = _empty
     choices: set = _empty
@@ -213,18 +216,13 @@ def update(self, other: dict) -> None:
 
     @classmethod
     def is_parameter_dict(cls, other: dict) -> bool:
-        """Return whether indicator has a parameter dictionary."""
+        """Return whether ``other`` can be used to update a Parameter."""
+        # Passing "compute_name" is forbidden.
+        # "name" is valid, but it is handled by the indicator.
         return set(other.keys()).issubset(
-            cls.__dataclass_fields__.keys()  # pylint: disable=no-member
+            {"kind", "default", "description", "units", "choices", "value", "name"}
         )
 
-    # def __getitem__(self, key) -> str:
-    #     """Return an item in retro-compatible fashion."""
-    #     try:
-    #         return str(getattr(self, key))
-    #     except AttributeError as err:
-    #         raise KeyError(key) from err
-
     def __contains__(self, key) -> bool:
         """Imitate previous behaviour where "units" and "choices" were missing, instead of being "_empty"."""
         return getattr(self, key, _empty) is not _empty
@@ -353,10 +351,8 @@ class Indicator(IndicatorRegistrar):
     # metadata fields that are formatted as free text (first letter capitalized)
     _text_fields = ["long_name", "description", "comment"]
-    # Class attributes that are function (so we know which to convert to static methods)
+    # Class attributes that are functions (so we know which to convert to static methods)
     _funcs = ["compute"]
-    # Mapping from name in the compute function to official (CMIP6) variable name
-    _variable_mapping = {}
 
     # Will become the class's name
     identifier = None
@@ -438,33 +434,19 @@ def __new__(cls, **kwds):  # noqa: C901
         else:  # inherit parameters from base class
             parameters = deepcopy(cls._all_parameters)
 
-        # Update parameters with passed parameters
+        # Update parameters with the passed values; this might change some parameter names (but not variable names)
         cls._update_parameters(parameters, kwds.pop("parameters", {}))
 
         # Input variable mapping (to change variable names in signature and expected units/cf attrs).
-        cls._parse_var_mapping(kwds.pop("input", {}), parameters, kwds)
+        # new_units maps compute-function argument names to units inferred from the variable mapping
+        new_units = cls._parse_var_mapping(kwds.pop("input", {}), parameters)
 
         # Raise on incorrect params, sort params, modify var defaults in-place if needed
         parameters = cls._ensure_correct_parameters(parameters)
 
-        # If needed, wrap compute with declare units
-        if "compute" in kwds:
-            if not hasattr(kwds["compute"], "in_units") and "_variable_mapping" in kwds:
-                # We actually need the inverse mapping (to get cmip6 name -> arg name)
-                inv_var_map = dict(map(reversed, kwds["_variable_mapping"].items()))
-                # parameters has already been update above.
- kwds["compute"] = declare_units( - **{ - inv_var_map[k]: m.units - for k, m in parameters.items() - if "units" in m and k in inv_var_map - } - )(kwds["compute"]) - - if hasattr(kwds["compute"], "in_units"): - varmap = kwds.get("_variable_mapping", {}) - for name, unit in kwds["compute"].in_units.items(): - parameters[varmap.get(name, name)].units = unit + if "compute" in kwds and not hasattr(kwds["compute"], "in_units") and new_units: + # If needed, wrap compute with declare units + kwds["compute"] = declare_units(**new_units)(kwds["compute"]) # All updates done. kwds["_all_parameters"] = parameters @@ -535,8 +517,11 @@ def _parse_indice(compute, passed_parameters): # noqa: F841 ) for name, param in compute_sig.parameters.items(): meta = params_dict.setdefault(name, {}) + meta["compute_name"] = name meta["default"] = param.default meta["kind"] = infer_kind_from_parameter(param) + if hasattr(compute, "in_units") and name in compute.in_units: + meta["units"] = compute.in_units[name] parameters = {name: Parameter(**param) for name, param in params_dict.items()} return parameters, docmeta @@ -561,7 +546,14 @@ def _update_parameters(cls, parameters, passed): try: for key, val in passed.items(): if isinstance(val, dict) and Parameter.is_parameter_dict(val): - # modified meta + if "name" in val: + new_key = val.pop("name") + if new_key in parameters: + raise ValueError( + f"Cannot rename a parameter or variable with the same name as another parameter. '{new_key}' is already a parameter." + ) + parameters[new_key] = parameters.pop(key) + key = new_key parameters[key].update(val) elif key in parameters: parameters[key].value = val @@ -574,9 +566,10 @@ def _update_parameters(cls, parameters, passed): ) from err @classmethod - def _parse_var_mapping(cls, variable_mapping, parameters, kwds): + def _parse_var_mapping(cls, variable_mapping, parameters): """Parse the variable mapping passed in `input` and update `parameters` in-place.""" # Update parameters + new_units = {} for old_name, new_name in variable_mapping.items(): meta = parameters[new_name] = parameters.pop(old_name) try: @@ -597,13 +590,9 @@ def _parse_var_mapping(cls, variable_mapping, parameters, kwds): f"{meta.units}, new = {varmeta['canonical_units']}" ) from err meta.units = varmeta.get("dimensions", varmeta["canonical_units"]) + new_units[meta.compute_name] = meta.units meta.description = varmeta["description"] - - if variable_mapping: - # Update mapping attribute - new_variable_mapping = deepcopy(cls._variable_mapping) - new_variable_mapping.update(variable_mapping) - kwds["_variable_mapping"] = new_variable_mapping + return new_units @classmethod def _ensure_correct_parameters(cls, parameters): @@ -795,7 +784,7 @@ def _gen_signature(self): name, kind=_Parameter.KEYWORD_ONLY, default=meta.default, - annotation=compute_sig.parameters[name].annotation, + annotation=compute_sig.parameters[meta.compute_name].annotation, ) ) @@ -827,22 +816,10 @@ def __call__(self, *args, **kwds): das, params = self._preprocess_and_checks(das, params) - # Get correct variable names for the compute function. - inv_var_map = dict(map(reversed, self._variable_mapping.items())) - compute_das = {inv_var_map.get(nm, nm): das[nm] for nm in das} - - # Compute the indicator values, ignoring NaNs and missing values. - # Filter the passed parameters to only keep the ones needed by compute. 
- kwargs = {} - var_kwargs = {} - for nm, pa in signature(self.compute).parameters.items(): - if pa.kind == _Parameter.VAR_KEYWORD: - var_kwargs = params[nm] - elif nm not in compute_das and nm in params: - kwargs[nm] = params[nm] - + # get mappings where keys are the actual compute function's argument names + args = self._get_compute_args(das, params) with xarray.set_options(keep_attrs=False): - outs = self.compute(**compute_das, **kwargs, **var_kwargs) + outs = self.compute(**args) if isinstance(outs, DataArray): outs = [outs] @@ -928,7 +905,7 @@ def _parse_variables_from_call(self, args, kwds) -> tuple[OrderedDict, dict]: else: params[name] = param.value - ds = ba.arguments.get("ds") + ds = params.get("ds") dsattrs = ds.attrs if ds is not None else {} return das, params, dsattrs @@ -964,6 +941,24 @@ def _preprocess_and_checks(self, das, params): self._bind_call(self.cfcheck, **das) return das, params + def _get_compute_args(self, das, params): + """Rename variables and parameters to match the compute function's names and split VAR_KEYWORD arguments.""" + # Get correct variable names for the compute function. + # Exclude param without a mapping inside the compute functions (those injected by the indicator class) + args = {} + for key, p in self._all_parameters.items(): + if p.compute_name is not _empty: + if key in das: + args[p.compute_name] = das[key] + # elif because some args are in both (percentile DataArrays) + elif key in params: + if p.kind == InputKind.KWARGS: + args.update(params[key]) + else: + args[p.compute_name] = params[key] + + return args + def _postprocess(self, outs, das, params): """Actions to done after computing.""" return outs @@ -1537,21 +1532,16 @@ def _preprocess_and_checks(self, das, params): class IndexingIndicator(Indicator): """Indicator that also injects "indexer" kwargs to subset the inputs before computation.""" - @classmethod - def _injected_parameters(cls): - return super()._injected_parameters() + [ - ( - "indexer", - Parameter( - kind=InputKind.KWARGS, - description=( - "Indexing parameters to compute the indicator on a temporal " - "subset of the data. It accepts the same arguments as " - ":py:func:`xclim.indices.generic.select_time`." - ), - ), - ) - ] + def __init__(self, *args, **kwargs): + self._all_parameters["indexer"] = Parameter( + kind=InputKind.KWARGS, + description=( + "Indexing parameters to compute the indicator on a temporal " + "subset of the data. It accepts the same arguments as " + ":py:func:`xclim.indices.generic.select_time`." + ), + ) + super().__init__(*args, **kwargs) def _preprocess_and_checks(self, das: dict[str, DataArray], params: dict[str, Any]): """Perform parent's checks and also check if freq is allowed.""" diff --git a/xclim/data/fr.json b/xclim/data/fr.json index 5c395e112..d11f72253 100644 --- a/xclim/data/fr.json +++ b/xclim/data/fr.json @@ -335,6 +335,24 @@ "title": "Nombre de jours ayant des températures minimale et maximale quotidiennes au-dessus de seuils donnés", "abstract": "Nombre de jours où les températures maximale et minimale quotidiennes sont au-dessus de seuils donnés." }, + "HEAT_SPELL_FREQUENCY": { + "long_name": "Nombre de vagues de chaleur.", + "description": "Nombre {freq:m} de vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. 
Les {window} jours de la moyenne sont comptés dans la vague.", + "title": "Fréquence des vagues de chaleur", + "abstract": "Nombre de vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." + }, + "HEAT_SPELL_MAX_LENGTH": { + "long_name": "Vague de chaleur la plus longue.", + "description": "Durée maximale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {window} jours de la moyenne sont comptés dans la vague.", + "title": "Durée maximale des vagues de chaleur", + "abstract": "Durée maximale des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." + }, + "HEAT_SPELL_TOTAL_LENGTH": { + "long_name": "Durée totale des vagues de chaleur.", + "description": "Durée totale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {window} jours de la moyenne sont comptés dans la vague.", + "title": "Durée totale des vagues de chaleur", + "abstract": "Durée totale des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." + }, "HEAT_WAVE_FREQUENCY": { "long_name": "Nombre de séries d'au moins {window} jours consécutifs ayant des températures minimales quotidiennes au-dessus {thresh_tasmin} et températures maximales quotidiennes au-dessus {thresh_tasmax}", "description": "Nombre {freq:m} de vagues de chaleur. Une vague de chaleur se produit lorsque les températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement, durant au moins {window} jours.", diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py index de22534eb..661169468 100644 --- a/xclim/indicators/atmos/_temperature.py +++ b/xclim/indicators/atmos/_temperature.py @@ -45,6 +45,9 @@ "growing_season_end", "growing_season_length", "growing_season_start", + "heat_spell_frequency", + "heat_spell_max_length", + "heat_spell_total_length", "heat_wave_frequency", "heat_wave_index", "heat_wave_max_length", @@ -201,7 +204,6 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): title="Heat wave frequency", identifier="heat_wave_frequency", units="", - standard_name="heat_wave_events", long_name="Total number of series of at least {window} consecutive days with daily minimum temperature above " "{thresh_tasmin} and daily maximum temperature above {thresh_tasmax}", description="{freq} number of heat wave events within a given period. A heat wave occurs when daily minimum and " @@ -223,7 +225,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): description="{freq} maximum length of heat wave events occurring within a given period. " "A heat wave occurs when daily minimum and maximum temperatures exceed {thresh_tasmin} and {thresh_tasmax}, " "respectively, over at least {window} days.", - abstract="Total duration of heat waves. A heat wave occurs when daily minimum and maximum temperatures exceed " + abstract="Maximal duration of heat waves. 
A heat wave occurs when daily minimum and maximum temperatures exceed "
     "given thresholds for a number of days.",
     cell_methods="",
     keywords="health,",
@@ -240,7 +242,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing):
     description="{freq} total length of heat wave events occurring within a given period. "
     "A heat wave occurs when daily minimum and maximum temperatures exceed {thresh_tasmin} and {thresh_tasmax}, "
     "respectively, over at least {window} days.",
-    abstract="Maximum length of heat waves. A heat wave occurs when daily minimum and maximum temperatures exceed "
+    abstract="Total length of heat waves. A heat wave occurs when daily minimum and maximum temperatures exceed "
     "given thresholds for a number of days.",
     cell_methods="",
     keywords="health,",
@@ -252,7 +254,6 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing):
     title="Heat wave index",
     identifier="heat_wave_index",
     units="days",
-    standard_name="heat_wave_index",
     long_name="Total number of days constituting events of at least {window} consecutive days "
     "with daily maximum temperature above {thresh}",
     description="{freq} total number of days that are part of a heatwave within a given period. "
     "A heat wave occurs when the daily maximum temperature exceeds {thresh} over at least {window} days.",
     keywords="health,",
     compute=indices.heat_wave_index,
 )
 
+heat_spell_frequency = Temp(
+    title="Heat spell frequency",
+    identifier="heat_spell_frequency",
+    units="",
+    long_name="Number of heat spells",
+    description="{freq} number of heat spell events. A heat spell occurs when the {window}-day "
+    "averages of daily minimum and maximum temperatures each exceed {thresh_tasmin} and {thresh_tasmax}. "
+    "All days of the {window}-day period are considered part of the spell.",
+    abstract="Number of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given "
+    "thresholds for a number of days.",
+    cell_methods="",
+    keywords="health,",
+    compute=indices.generic.bivariate_spell_length_statistics,
+    input={"data1": "tasmin", "data2": "tasmax"},
+    parameters=dict(
+        spell_reducer="count",
+        op=">=",
+        window={"default": 3},
+        win_reducer={"default": "mean"},
+        freq={"default": "YS"},
+        threshold1={
+            "description": "Threshold for tasmin",
+            "default": "20 °C",
+            "name": "thresh_tasmin",
+        },
+        threshold2={
+            "description": "Threshold for tasmax",
+            "default": "33 °C",
+            "name": "thresh_tasmax",
+        },
+    ),
+)
+
+heat_spell_max_length = Temp(
+    title="Heat spell maximum length",
+    identifier="heat_spell_max_length",
+    units="days",
+    standard_name="spell_length_of_days_with_air_temperature_above_threshold",
+    long_name="Longest heat spell",
+    description="{freq} maximum length of heat spells. A heat spell occurs when the {window}-day "
+    "averages of daily minimum and maximum temperatures each exceed {thresh_tasmin} and {thresh_tasmax}. "
+    "All days of the {window}-day period are considered part of the spell.",
+    abstract="The longest heat spell of a period. 
A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " + "thresholds for a number of days.", + compute=indices.generic.bivariate_spell_length_statistics, + input={"data1": "tasmin", "data2": "tasmax"}, + parameters=dict( + spell_reducer="max", + op=">=", + window={"default": 3}, + win_reducer={"default": "mean"}, + freq={"default": "YS"}, + threshold1={ + "description": "Threshold for tasmin", + "default": "20 °C", + "name": "thresh_tasmin", + }, + threshold2={ + "description": "Threshold for tasmax", + "default": "33 °C", + "name": "thresh_tasmax", + }, + ), +) + +heat_spell_total_length = Temp( + title="Heat spell total length", + identifier="heat_spell_total_length", + units="days", + standard_name="spell_length_of_days_with_air_temperature_above_threshold", + long_name="Total length of heat spells.", + description="{freq} total length of heat spell events. " + "A heat spell occurs when the {window}-day averages of daily minimum and maximum temperatures " + "each exceed {thresh_tasmin} and {thresh_tasmax}. All days of the {window}-day period are considered part of the spell.", + abstract="Total length of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " + "thresholds for a number of days.", + compute=indices.generic.bivariate_spell_length_statistics, + input={"data1": "tasmin", "data2": "tasmax"}, + parameters=dict( + spell_reducer="sum", + op=">=", + window={"default": 3}, + win_reducer={"default": "mean"}, + freq={"default": "YS"}, + threshold1={ + "description": "Threshold for tasmin", + "default": "20 °C", + "name": "thresh_tasmin", + }, + threshold2={ + "description": "Threshold for tasmax", + "default": "33 °C", + "name": "thresh_tasmax", + }, + ), +) + hot_spell_frequency = Temp( title="Hot spell frequency", identifier="hot_spell_frequency", diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index ec549d83a..43bc85720 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -32,6 +32,7 @@ __all__ = [ "aggregate_between_dates", "binary_ops", + "bivariate_spell_length_statistics", "compare", "count_level_crossings", "count_occurrences", @@ -54,6 +55,7 @@ "select_resample_op", "spell_length", "spell_length_statistics", + "spell_mask", "statistics", "temperature_sum", "threshold_count", @@ -353,6 +355,153 @@ def get_daily_events( return events +def spell_mask( + data: xr.DataArray | Sequence[xr.DataArray], + window: int, + win_reducer: str, + op: str, + thresh: float | Sequence[float], + weights: Sequence[float] = None, + var_reducer: str = "all", +) -> xr.DataArray: + """Compute the boolean mask of data points that are part of a spell as defined by a rolling statistic. + + A day is part of a spell (True in the mask) if it is contained in any period that fulfills the condition. + + Parameters + ---------- + data: DataArray or sequence of DataArray + The input data. Can be a list, in which case the condition is checked on all variables. + See var_reducer for the latter case. + window: int + The length of the rolling window in which to compute statistics. + win_reducer: {'min', 'max', 'sum', 'mean'} + The statistics to compute on the rolling window. + op: {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + The comparison operator to use when finding spells. + thresh: float or sequence of floats + The threshold to compare the rolling statistics against, as ``window_stats op threshold``. 
+        If data is a list, this must be a list of the same length with a threshold for each variable.
+        This function does not handle units and can't accept Quantified objects.
+    weights: sequence of floats
+        A list of weights of the same length as the window.
+        Only supported if `win_reducer` is "mean".
+    var_reducer: {'all', 'any'}
+        If the data is a list, the condition must either be fulfilled on *all*
+        or *any* variables for the period to be considered a spell.
+
+    Returns
+    -------
+    xr.DataArray
+        Same shape as ``data``, but boolean.
+        If ``data`` was a list, this is a DataArray of the same shape as the alignment of all variables.
+    """
+    # Checks
+    if not isinstance(data, xr.DataArray):
+        # ``data`` is thus a sequence
+        if np.isscalar(thresh) or len(data) != len(thresh):
+            raise ValueError(
+                "When ``data`` is given as a list, ``thresh`` must be a sequence of the same length."
+            )
+        data = xr.concat(data, "variable")
+        if isinstance(thresh[0], xr.DataArray):
+            thresh = xr.concat(thresh, "variable")
+        else:
+            thresh = xr.DataArray(thresh, dims=("variable",))
+    if weights is not None:
+        if win_reducer != "mean":
+            raise ValueError(
+                f"Argument 'weights' is only supported if 'win_reducer' is 'mean'. Got: {win_reducer}"
+            )
+        elif len(weights) != window:
+            raise ValueError(
+                f"Weights have a different length ({len(weights)}) than the window ({window})."
+            )
+        weights = xr.DataArray(weights, dims=("window",))
+
+    if window == 1:  # Fast path
+        is_in_spell = compare(data, op, thresh)
+        if not np.isscalar(thresh):
+            is_in_spell = getattr(is_in_spell, var_reducer)("variable")
+    elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or (
+        win_reducer == "max" and op in ["<", "<=", "le", "lt"]
+    ):
+        # Fast path for specific cases; this yields a smaller dask graph (rolling twice is expensive!)
+        # For these two cases, a day can't be part of a spell if it doesn't respect the condition itself
+        mask = compare(data, op, thresh)
+        if not np.isscalar(thresh):
+            mask = getattr(mask, var_reducer)("variable")
+        # We need to filter out the spells shorter than "window"
+        # find sequences of consecutive respected constraints
+        cs_s = rl._cumsum_reset_on_zero(mask)
+        # end of these sequences
+        cs_s = cs_s.where(mask.shift({"time": -1}, fill_value=0) == 0)
+        # propagate these ends of sequences
+        # the `.where(mask>0, 0)` acts as a stopper
+        is_in_spell = cs_s.where(cs_s >= window).where(mask > 0, 0).bfill("time") > 0
+    else:
+        data_pad = data.pad(time=(0, window))
+        # The spell-wise value to test
+        # For example "win_reducer='sum'", we want the sum over the minimum spell length (window) to be above the thresh
+        if weights is not None:
+            spell_value = data_pad.rolling(time=window).construct("window").dot(weights)
+        else:
+            spell_value = getattr(data_pad.rolling(time=window), win_reducer)()
+        # True at the end of a spell respecting the condition
+        mask = compare(spell_value, op, thresh)
+        if not np.isscalar(thresh):
+            mask = getattr(mask, var_reducer)("variable")
+        # True for all days part of a spell that respected the condition (shift because of the two rollings)
+        is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1))
+        # Cut back to the original size
+        is_in_spell = is_in_spell.isel(time=slice(0, data.time.size))
+    return is_in_spell
+
+
+def _spell_length_statistics(
+    data: xr.DataArray | Sequence[xr.DataArray],
+    thresh: float | xr.DataArray | Sequence[xr.DataArray] | Sequence[float],
+    window: int,
+    win_reducer: str,
+    op: str,
+    spell_reducer: str | Sequence[str],
+    freq: str,
+    resample_before_rl: bool = True,
+    **indexer,
+) -> xr.DataArray | Sequence[xr.DataArray]:
+    if isinstance(spell_reducer, str):
+        spell_reducer = [spell_reducer]
+    is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32)
+    is_in_spell = select_time(is_in_spell, **indexer)
+
+    outs = []
+    for sr in spell_reducer:
+        out = rl.resample_and_rl(
+            is_in_spell,
+            resample_before_rl,
+            rl.rle_statistics,
+            reducer=sr,
+            # The code above already ensured that only spells of the minimum length are selected
+            window=1,
+            freq=freq,
+        )
+
+        if sr == "count":
+            outs.append(out.assign_attrs(units=""))
+        else:
+            # All other cases are statistics of the number of timesteps
+            outs.append(
+                to_agg_units(
+                    out,
+                    data if isinstance(data, xr.DataArray) else data[0],
+                    "count",
+                )
+            )
+    if len(outs) == 1:
+        return outs[0]
+    return tuple(outs)
+
+
 @declare_relative_units(threshold="")
 def spell_length_statistics(
     data: xr.DataArray,
@@ -368,7 +517,7 @@
     r"""Statistics on spells lengths.
 
     A spell is when a statistic (`win_reducer`) over a minimum number (`window`) of consecutive timesteps respects a condition (`op` `thresh`).
-    This returns a statistic over the spell's count or length.
+    This returns a statistic over the spell count or spell lengths.
 
     Parameters
     ----------
@@ -383,8 +532,8 @@
         Note that this does not matter when `window` is 1.
     op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
         Logical operator. Ex: spell_value > thresh.
-    spell_reducer : {'max', 'sum', 'count'}
-        Statistic on the spell lengths.
+    spell_reducer : {'max', 'sum', 'count'} or sequence thereof
+        Statistic on the spell lengths. If a list, multiple statistics are computed.
     freq : str
         Resampling frequency.
    resample_before_rl : bool
        Determines if the resampling should take place before or after the run
        length encoding (or a similar algorithm) is applied to runs.
@@ -423,58 +572,93 @@
     Here, a day is part of a spell if it is in any five (5) day period where the total accumulated precipitation reaches or exceeds 20 mm.
     We then return the length of the longest of such spells.
+
+    See Also
+    --------
+    spell_mask : The lower-level function that finds spells.
+    bivariate_spell_length_statistics : The bivariate version of this function.
     """
-    thresh = convert_units_to(
-        threshold,
+    thresh = convert_units_to(threshold, data, context="infer")
+    return _spell_length_statistics(
         data,
-        context=infer_context(standard_name=data.attrs.get("standard_name")),
+        thresh,
+        window,
+        win_reducer,
+        op,
+        spell_reducer,
+        freq,
+        resample_before_rl,
+        **indexer,
     )
 
-    if window == 1:  # Fast path
-        is_in_spell = compare(data, op, thresh)
-    elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or (
-        win_reducer == "max" and op in ["`<", "<=", "le", "lt"]
-    ):
-        # Fast path for specific cases, this yields a smaller dask graph (rolling twice is expensive!)
-        # For these two cases, a day can't be part of a spell if it doesn't respect the condition itself
-        mask = compare(data, op, thresh)
-        # We need to filter out the spells shorter than "window"
-        # find sequences of consecutive respected constraints
-        cs_s = rl._cumsum_reset_on_zero(mask)
-        # end of these sequences
-        cs_s = cs_s.where(mask.shift({"time": -1}, fill_value=0) == 0)
-        # propagate these end of sequences
-        # the `.where(mask>0, 0)` acts a stopper
-        is_in_spell = cs_s.where(cs_s >= window).where(mask > 0, 0).bfill("time") > 0
-    else:
-        data_pad = data.pad(time=(0, window))
-        # The spell-wise value to test
-        # For example "win_reducer='sum'", we want the sum over the minimum spell length (window) to be above the thresh
-        spell_value = getattr(data_pad.rolling(time=window), win_reducer)()
-        # True at the end of a spell respecting the condition
-        mask = compare(spell_value, op, thresh)
-        # True for all days part of a spell that respected the condition (shift because of the two rollings)
-        is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1))
-        # Cut back to the original size
-        is_in_spell = is_in_spell.isel(time=slice(0, data.time.size)).astype(float)
-    is_in_spell = select_time(is_in_spell, **indexer)
 
+@declare_relative_units(threshold1="", threshold2="")
+def bivariate_spell_length_statistics(
+    data1: xr.DataArray,
+    threshold1: Quantified,
+    data2: xr.DataArray,
+    threshold2: Quantified,
+    window: int,
+    win_reducer: str,
+    op: str,
+    spell_reducer: str | Sequence[str],
+    freq: str,
+    resample_before_rl: bool = True,
+    **indexer,
+):
+    r"""Statistics on spell lengths based on two variables.
+
+    A spell is when a statistic (`win_reducer`) over a minimum number (`window`) of consecutive timesteps respects a condition (`op` `thresh`).
+    This returns a statistic over the spell count or spell lengths. In this bivariate version, conditions on both variables must be fulfilled.
+
+    Parameters
+    ----------
+    data1 : xr.DataArray
+        First input data.
+    threshold1 : Quantified
+        Threshold to test against data1.
+    data2 : xr.DataArray
+        Second input data.
+    threshold2 : Quantified
+        Threshold to test against data2.
+    window : int
+        Minimum length of a spell.
+    win_reducer : {'min', 'max', 'sum', 'mean'}
+        Reduction along the spell length to compute the spell value.
+        Note that this does not matter when `window` is 1.
+    op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
+        Logical operator. Ex: spell_value > thresh.
+    spell_reducer : {'max', 'sum', 'count'} or sequence thereof
+        Statistic on the spell lengths. If a list, multiple statistics are computed.
+    freq : str
+        Resampling frequency.
+    resample_before_rl : bool
+        Determines if the resampling should take place before or after the run
+        length encoding (or a similar algorithm) is applied to runs.
+    \*\*indexer
+        Indexing parameters to compute the indicator on a temporal subset of the data.
+        It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`.
+        Indexing is done after finding the days that are part of a spell, but before taking the spell statistics.
-    out = rl.resample_and_rl(
-        is_in_spell,
+
+    See Also
+    --------
+    spell_length_statistics : The univariate version.
+    spell_mask : The lower-level function that finds spells.
+    """
+    thresh1 = convert_units_to(threshold1, data1, context="infer")
+    thresh2 = convert_units_to(threshold2, data2, context="infer")
+    return _spell_length_statistics(
+        [data1, data2],
+        [thresh1, thresh2],
+        window,
+        win_reducer,
+        op,
+        spell_reducer,
+        freq,
         resample_before_rl,
-        rl.rle_statistics,
-        reducer=spell_reducer,
-        # The code above already ensured only spell of the minimum length are selected
-        window=1,
-        freq=freq,
+        **indexer,
     )
-    if spell_reducer == "count":
-        return out.assign_attrs(units="")
-    # All other cases are statistics of the number of timesteps
-    return to_agg_units(out, data, "count")
-
 
 @declare_relative_units(thresh="")
 def season(
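To close, a minimal end-to-end sketch of the new bivariate generic index, mirroring ``test_spell_length_statistics_multi`` from this diff. The constant 270 K series stand in for the ``tasmin_series``/``tasmax_series`` test fixtures; the exact time-coordinate construction here is an assumption about those fixtures, not part of the patch.

```python
import numpy as np
import pandas as pd
import xarray as xr

from xclim.indices import generic

# Constant 270 K daily series standing in for the tasmin/tasmax fixtures.
time = pd.date_range("2001-01-01", periods=365, freq="D")
tn = xr.DataArray(
    np.full(365, 270.0), dims="time", coords={"time": time}, attrs={"units": "K"}
)
tx = tn.copy()

# Both variables stay below their thresholds (0 degC and 1 degC) all year,
# so every 5-day window's minimum satisfies "<" and the year is one long spell.
count, total, longest = generic.bivariate_spell_length_statistics(
    tn, "0 °C", tx, "1 °C",
    window=5, win_reducer="min", op="<",
    spell_reducer=["count", "sum", "max"],
    freq="YS",
)
np.testing.assert_allclose(count, 1)     # a single spell in the year
xr.testing.assert_equal(total, longest)  # that spell is also the longest one
```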