From 55ccc8e831507cec30889fdf17ff76a5d6144b26 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 10:20:22 -0400 Subject: [PATCH 01/14] New spell generic - Heat spell - change param name in Ind --- xclim/core/formatting.py | 2 + xclim/core/indicator.py | 14 +- xclim/data/fr.json | 18 +++ xclim/indicators/atmos/_temperature.py | 106 +++++++++++- xclim/indices/generic.py | 214 +++++++++++++++++++++---- 5 files changed, 314 insertions(+), 40 deletions(-) diff --git a/xclim/core/formatting.py b/xclim/core/formatting.py index aaa18025b..e9883ec19 100644 --- a/xclim/core/formatting.py +++ b/xclim/core/formatting.py @@ -582,6 +582,8 @@ def _gen_parameters_section( """ section = "Parameters\n----------\n" for name, param in parameters.items(): + if param.name is not param._empty: + name = param.name desc_str = param.description if param.kind == InputKind.FREQ_STR: desc_str += ( diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 3b68e09c2..75fc6a3ee 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -78,8 +78,9 @@ units: # Only valid if "compute" points to a generic function default : description: + name : # A name to use when generating the docstring and signature. kind: # Override the parameter kind. - # This is mostly useful for transforming an optional variable into a required one by passing ``kind: 0``. + # This is mostly useful for transforming an optional variable into a required one by passing ``kind: 0``. ... ... # and so on. @@ -179,7 +180,7 @@ class _empty: class Parameter: """Class for storing an indicator's controllable parameter. - For retrocompatibility, this class implements a "getitem" and a special "contains". + For convenience, this class implements a special "contains". Example ------- @@ -195,6 +196,7 @@ class Parameter: _empty = _empty kind: InputKind + name: str = _empty default: Any = _empty_default description: str = "" units: str = _empty @@ -788,9 +790,10 @@ def _gen_signature(self): ) ) else: + show_name = name if meta.name is meta._empty else meta.name parameters.append( _Parameter( - name, + show_name, kind=_Parameter.KEYWORD_ONLY, default=meta.default, annotation=compute_sig.parameters[name].annotation, @@ -906,6 +909,11 @@ def _parse_variables_from_call(self, args, kwds) -> tuple[OrderedDict, dict]: ba = self.__signature__.bind(*args, **kwds) ba.apply_defaults() + # Assign parameters with different signature-names correctly + for name, param in self._all_parameters.items(): + if param.name is not param._empty: + ba.arguments[name] = ba.arguments.pop(param.name) + # Assign inputs passed as strings from ds. self._assign_named_args(ba) diff --git a/xclim/data/fr.json b/xclim/data/fr.json index 5c395e112..07c2e0933 100644 --- a/xclim/data/fr.json +++ b/xclim/data/fr.json @@ -335,6 +335,24 @@ "title": "Nombre de jours ayant des températures minimale et maximale quotidiennes au-dessus de seuils donnés", "abstract": "Nombre de jours où les températures maximale et minimale quotidiennes sont au-dessus de seuils donnés." }, + "HEAT_SPELL_FREQUENCY": { + "long_name": "Nombre de vagues de chaleur.", + "description": "Nombre {freq:m} de vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {windows} jours de la moyenne sont comptés dans la vague.", + "title": "Fréquence des vagues de chaleur", + "abstract": "Nombre de vagues de chaleur. 
Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." + }, + "HEAT_SPELL_MAX_LENGTH": { + "long_name": "Vague de chaleur la plus longue.", + "description": "Durée maximale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {windows} jours de la moyenne sont comptés dans la vague.", + "title": "Durée maximale des vagues de chaleur", + "abstract": "Durée maximale des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." + }, + "HEAT_SPELL_TOTAL_LENGTH": { + "long_name": "Durée totale des vagues de chaleur.", + "description": "Durée totale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {windows} jours de la moyenne sont comptés dans la vague.", + "title": "Durée totale des vagues de chaleur", + "abstract": "Durée totale des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." + }, "HEAT_WAVE_FREQUENCY": { "long_name": "Nombre de séries d'au moins {window} jours consécutifs ayant des températures minimales quotidiennes au-dessus {thresh_tasmin} et températures maximales quotidiennes au-dessus {thresh_tasmax}", "description": "Nombre {freq:m} de vagues de chaleur. Une vague de chaleur se produit lorsque les températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement, durant au moins {window} jours.", diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py index 54caf2c0c..6140c2abf 100644 --- a/xclim/indicators/atmos/_temperature.py +++ b/xclim/indicators/atmos/_temperature.py @@ -45,6 +45,9 @@ "growing_season_end", "growing_season_length", "growing_season_start", + "heat_spell_frequency", + "heat_spell_max_length", + "heat_spell_total_length", "heat_wave_frequency", "heat_wave_index", "heat_wave_max_length", @@ -201,7 +204,6 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): title="Heat wave frequency", identifier="heat_wave_frequency", units="", - standard_name="heat_wave_events", long_name="Total number of series of at least {window} consecutive days with daily minimum temperature above " "{thresh_tasmin} and daily maximum temperature above {thresh_tasmax}", description="{freq} number of heat wave events within a given period. A heat wave occurs when daily minimum and " @@ -223,7 +225,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): description="{freq} maximum length of heat wave events occurring within a given period. " "A heat wave occurs when daily minimum and maximum temperatures exceed {thresh_tasmin} and {thresh_tasmax}, " "respectively, over at least {window} days.", - abstract="Total duration of heat waves. A heat wave occurs when daily minimum and maximum temperatures exceed " + abstract="Maximal duration of heat waves. 
A heat wave occurs when daily minimum and maximum temperatures exceed " "given thresholds for a number of days.", cell_methods="", keywords="health,", @@ -240,7 +242,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): description="{freq} total length of heat wave events occurring within a given period. " "A heat wave occurs when daily minimum and maximum temperatures exceed {thresh_tasmin} and {thresh_tasmax}, " "respectively, over at least {window} days.", - abstract="Maximum length of heat waves. A heat wave occurs when daily minimum and maximum temperatures exceed " + abstract="Total length of heat waves. A heat wave occurs when daily minimum and maximum temperatures exceed " "given thresholds for a number of days.", cell_methods="", keywords="health,", @@ -252,7 +254,6 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): title="Heat wave index", identifier="heat_wave_index", units="days", - standard_name="heat_wave_index", long_name="Total number of days constituting events of at least {window} consecutive days " "with daily maximum temperature above {thresh}", description="{freq} total number of days that are part of a heatwave within a given period. " @@ -263,6 +264,103 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): compute=indices.heat_wave_index, ) +heat_spell_frequency = Temp( + title="Heat spell frequency", + identifier="heat_spell_frequency", + units="", + long_name="Number of heat spells", + description="{freq} number of heat spells events within a given period. A heat spell occurs when the {window}-day " + "averages of daily minimum and maximum temperatures each exceed {threshold1} and {threshold2}. " + "All days of the {window}-day period are considered part of the spell.", + abstract="Number of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " + "thresholds for a number of days.", + cell_methods="", + keywords="health,", + compute=indices.generic.bivariate_spell_length_statistics, + input={"data1": "tasmin", "data2": "tasmax"}, + parameters=dict( + win_reducer="mean", + spell_reducer="count", + window={"default": 3}, + op={"default": ">="}, + freq={"default": "YS"}, + threshold1={ + "description": "Threshold for tasmin", + "default": "22 °C", + "name": "thresh_tasmin", + }, + threshold2={ + "description": "Threshold for tasmax", + "default": "30 °C", + "name": "thresh_tasmax", + }, + ), +) + +heat_spell_max_length = Temp( + title="Heat spell maximum length", + identifier="heat_spell_max_length", + units="days", + standard_name="spell_length_of_days_with_air_temperature_above_threshold", + long_name="Longest heat spell", + description="{freq} maximum length of heat spells within a given period. A heat spell occurs when the {window}-day " + "averages of daily minimum and maximum temperatures each exceed {threshold1} and {threshold2}. " + "All days of the {window}-day period are considered part of the spell.", + abstract="The longest heat spell of a period. 
A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " + "thresholds for a number of days.", + compute=indices.generic.bivariate_spell_length_statistics, + input={"data1": "tasmin", "data2": "tasmax"}, + parameters=dict( + win_reducer="mean", + spell_reducer="max", + window={"default": 3}, + op={"default": ">="}, + freq={"default": "YS"}, + threshold1={ + "description": "Threshold for tasmin", + "default": "22 °C", + "name": "thresh_tasmin", + }, + threshold2={ + "description": "Threshold for tasmax", + "default": "30 °C", + "name": "thresh_tasmax", + }, + ), +) + +heat_spell_total_length = Temp( + title="Heat spell total length", + identifier="heat_spell_total_length", + units="days", + standard_name="spell_length_of_days_with_air_temperature_above_threshold", + long_name="Total length of heat spells.", + description="{freq} total length of heat spell events occurring within a given period. " + "A heat spell occurs when the {window}-day averages of daily minimum and maximum temperatures " + "each exceed {threshold1} and {threshold2}. All days of the {window}-day period are considered part of the spell.", + abstract="Total length of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " + "thresholds for a number of days.", + compute=indices.generic.bivariate_spell_length_statistics, + input={"data1": "tasmin", "data2": "tasmax"}, + parameters=dict( + win_reducer="mean", + spell_reducer="max", + window={"default": 3}, + op={"default": ">="}, + freq={"default": "YS"}, + threshold1={ + "description": "Threshold for tasmin", + "default": "22 °C", + "name": "thresh_tasmin", + }, + threshold2={ + "description": "Threshold for tasmax", + "default": "30 °C", + "name": "thresh_tasmax", + }, + ), +) + hot_spell_frequency = Temp( title="Hot spell frequency", identifier="hot_spell_frequency", diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index 282a4f40c..4b9aab4e1 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -354,6 +354,105 @@ def get_daily_events( return events +def spell_mask( + data: xarray.DataArray | Sequence[xarray.DataArray], + window: int, + win_reducer: str, + op: str, + thresh: float | Sequence[float], + weights: Sequence[float] = None, + var_reducer: str = "all", +) -> xarray.DataArray: + """Compute the boolean mask of data points that are part of a spell as defined by a rolling statistic. + + A day is part of a spell (True in the mask) if it is contained in any period that fulfills the condition. + + Parameters + ---------- + data: DataArray or sequence of DataArray + The input data. Can be a list, in which case the condition is checked on all variables. + See var_reducer for the latter case. + window: int + The length of the rolling window in which to compute statistics. + win_reducer: {'min', 'max', 'sum', 'mean'} + The statistics to compute on the rolling window. + op: {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + The comparison operator to use when finding spells. + thresh: float or sequence of floats + The threshold to compare the rolling statistics against, as ``window_stats op threshold``. + If data is a list, this must be a list of the same length with a threshold for each variable. + weights: sequence of floats + A list of weights of the same length as the window. + Only supported if `win_reducer` is "mean". 
+    var_reducer: {'all', 'any'}
+        If the data is a list, the condition must either be fulfilled on *all*
+        or *any* variables for the period to be considered a spell.
+
+    Returns
+    -------
+    xarray.DataArray
+        Same shape as ``data``, but boolean.
+        If ``data`` was a Dataset, this is a DataArray of the same shape as the alignment of all variables.
+    """
+    # Checks
+    if not isinstance(data, xarray.DataArray):
+        # thus a sequence
+        if np.isscalar(thresh) or len(data) != len(thresh):
+            raise ValueError(
+                "When ``data`` is given as a list, ``threshold`` must be a sequence of the same length."
+            )
+        data = xarray.concat(data, "variable")
+        thresh = xarray.DataArray(thresh, dims=("variable",))
+    if weights is not None:
+        if win_reducer != "mean":
+            raise ValueError(
+                f"Argument 'weights' is only supported if 'window_reducer' is 'mean'. Got : {win_reducer}"
+            )
+        elif len(weights) != window:
+            raise ValueError(
+                f"Weights have a different length ({len(weights)}) than the window ({window})."
+            )
+        weights = xarray.DataArray(weights, dims=("window",))
+
+    if window == 1:  # Fast path
+        is_in_spell = compare(data, op, thresh)
+        if not np.isscalar(thresh):
+            is_in_spell = getattr(is_in_spell, var_reducer)("variable")
+    elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or (
+        win_reducer == "max" and op in ["<", "<=", "le", "lt"]
+    ):
+        # Fast path for specific cases, this yields a smaller dask graph (rolling twice is expensive!)
+        # For these two cases, a day can't be part of a spell if it doesn't respect the condition itself
+        mask = compare(data, op, thresh)
+        if not np.isscalar(thresh):
+            mask = getattr(mask, var_reducer)("variable")
+        # We need to filter out the spells shorter than "window"
+        # find sequences of consecutive respected constraints
+        cs_s = rl._cumsum_reset_on_zero(mask)
+        # end of these sequences
+        cs_s = cs_s.where(mask.shift({"time": -1}, fill_value=0) == 0)
+        # propagate these end of sequences
+        # the `.where(mask>0, 0)` acts as a stopper
+        is_in_spell = cs_s.where(cs_s >= window).where(mask > 0, 0).bfill("time") > 0
+    else:
+        data_pad = data.pad(time=(0, window))
+        # The spell-wise value to test
+        # For example "win_reducer='sum'", we want the sum over the minimum spell length (window) to be above the thresh
+        if weights is not None:
+            spell_value = data_pad.rolling(time=window).construct("window").dot(weights)
+        else:
+            spell_value = getattr(data_pad.rolling(time=window), win_reducer)()
+        # True at the end of a spell respecting the condition
+        mask = compare(spell_value, op, thresh)
+        if not np.isscalar(thresh):
+            mask = getattr(mask, var_reducer)("variable")
+        # True for all days part of a spell that respected the condition (shift because of the two rollings)
+        is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1))
+        # Cut back to the original size
+        is_in_spell = is_in_spell.isel(time=slice(0, data.time.size))
+    return is_in_spell
+
+
 @declare_relative_units(threshold="")
 def spell_length_statistics(
     data: xarray.DataArray,
@@ -369,7 +468,7 @@ def spell_length_statistics(
     r"""Statistics on spells lengths.
 
     A spell is when a statistic (`win_reducer`) over a minimum number (`window`) of consecutive timesteps respects a condition (`op` `thresh`).
-    This returns a statistic over the spell's count or length.
+    This returns a statistic over the spells count or lengths.
Parameters ---------- @@ -424,41 +523,90 @@ def spell_length_statistics( Here, a day is part of a spell if it is in any five (5) day period where the total accumulated precipitation reaches or exceeds 20 mm. We then return the length of the longest of such spells. + + See Also + -------- + spell_mask : The lower level functions that finds spells. + bivariate_spell_length_statistics : The bivariate version of this function. """ - thresh = convert_units_to( - threshold, - data, - context=infer_context(standard_name=data.attrs.get("standard_name")), + thresh = convert_units_to(threshold, data, context="infer") + is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) + is_in_spell = select_time(is_in_spell, **indexer) + + out = rl.resample_and_rl( + is_in_spell, + resample_before_rl, + rl.rle_statistics, + reducer=spell_reducer, + # The code above already ensured only spell of the minimum length are selected + window=1, + freq=freq, ) - if window == 1: # Fast path - is_in_spell = compare(data, op, thresh) - elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or ( - win_reducer == "max" and op in ["`<", "<=", "le", "lt"] - ): - # Fast path for specific cases, this yields a smaller dask graph (rolling twice is expensive!) - # For these two cases, a day can't be part of a spell if it doesn't respect the condition itself - mask = compare(data, op, thresh) - # We need to filter out the spells shorter than "window" - # find sequences of consecutive respected constraints - cs_s = rl._cumsum_reset_on_zero(mask) - # end of these sequences - cs_s = cs_s.where(mask.shift({"time": -1}, fill_value=0) == 0) - # propagate these end of sequences - # the `.where(mask>0, 0)` acts a stopper - is_in_spell = cs_s.where(cs_s >= window).where(mask > 0, 0).bfill("time") > 0 - else: - data_pad = data.pad(time=(0, window)) - # The spell-wise value to test - # For example "win_reducer='sum'", we want the sum over the minimum spell length (window) to be above the thresh - spell_value = getattr(data_pad.rolling(time=window), win_reducer)() - # True at the end of a spell respecting the condition - mask = compare(spell_value, op, thresh) - # True for all days part of a spell that respected the condition (shift because of the two rollings) - is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1)) - # Cut back to the original size - is_in_spell = is_in_spell.isel(time=slice(0, data.time.size)).astype(float) + if spell_reducer == "count": + return out.assign_attrs(units="") + # All other cases are statistics of the number of timesteps + return to_agg_units(out, data, "count") + +@declare_relative_units(threshold1="", threshold2="") +def bivariate_spell_length_statistics( + data1: xarray.DataArray, + threshold1: Quantified, + data2: xarray.DataArray, + threshold2: Quantified, + window: int, + win_reducer: str, + op: str, + spell_reducer: str, + freq: str, + resample_before_rl: bool = True, + **indexer, +): + r"""Statistics on spells lengths based on two variables. + + A spell is when a statistic (`win_reducer`) over a minimum number (`window`) of consecutive timesteps respects a condition (`op` `thresh`). + This returns a statistic over the spells count or lengths. In this bivariate version, conditions on both variables must be fulfilled. + + Parameters + ---------- + data1 : xr.DataArray + First input data. + threshold1 : Quantified + Threshold to test against data1. + data2 : xr.DataArray + Second input data. 
+ threshold2 : Quantified + Threshold to test against data2. + window : int + Minimum length of a spell. + win_reducer : {'min', 'max', 'sum', 'mean'} + Reduction along the spell length to compute the spell value. + Note that this does not matter when `window` is 1. + op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} + Logical operator. Ex: spell_value > thresh. + spell_reducer : {'max', 'sum', 'count'} + Statistic on the spell lengths. + freq : str + Resampling frequency. + resample_before_rl : bool + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. + \*\*indexer + Indexing parameters to compute the indicator on a temporal subset of the data. + It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`. + Indexing is done after finding the days part of a spell, but before taking the spell statistics. + + See Also + -------- + spell_length_statistics: The univariate version. + spell_mask : The lower level functions that finds spells. + """ + thresh1 = convert_units_to(threshold1, data1, context="infer") + thresh2 = convert_units_to(threshold2, data2, context="infer") + is_in_spell = spell_mask( + [data1, data2], window, win_reducer, op, [thresh1, thresh2], var_reducer="all" + ).astype(np.float32) is_in_spell = select_time(is_in_spell, **indexer) out = rl.resample_and_rl( @@ -474,7 +622,7 @@ def spell_length_statistics( if spell_reducer == "count": return out.assign_attrs(units="") # All other cases are statistics of the number of timesteps - return to_agg_units(out, data, "count") + return to_agg_units(out, data1, "count") @declare_relative_units(thresh="") From e5cb1dbff2432e72e7fe54d1d7cc3f7915a3f51a Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 10:41:40 -0400 Subject: [PATCH 02/14] update thresholds --- xclim/indicators/atmos/_temperature.py | 12 ++++++------ xclim/indices/generic.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py index 6140c2abf..ae5e5122b 100644 --- a/xclim/indicators/atmos/_temperature.py +++ b/xclim/indicators/atmos/_temperature.py @@ -286,12 +286,12 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", - "default": "22 °C", + "default": "20 °C", "name": "thresh_tasmin", }, threshold2={ "description": "Threshold for tasmax", - "default": "30 °C", + "default": "33 °C", "name": "thresh_tasmax", }, ), @@ -318,12 +318,12 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", - "default": "22 °C", + "default": "20 °C", "name": "thresh_tasmin", }, threshold2={ "description": "Threshold for tasmax", - "default": "30 °C", + "default": "33 °C", "name": "thresh_tasmax", }, ), @@ -350,12 +350,12 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", - "default": "22 °C", + "default": "20 °C", "name": "thresh_tasmin", }, threshold2={ "description": "Threshold for tasmax", - "default": "30 °C", + "default": "33 °C", "name": "thresh_tasmax", }, ), diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index 4b9aab4e1..c246f1776 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -381,6 +381,7 @@ def spell_mask( thresh: float or sequence of floats The 
threshold to compare the rolling statistics against, as ``window_stats op threshold``. If data is a list, this must be a list of the same length with a threshold for each variable. + This function does not handle units and can't accept Quantified objects. weights: sequence of floats A list of weights of the same length as the window. Only supported if `win_reducer` is "mean". @@ -392,7 +393,7 @@ def spell_mask( ------- xarray.DataArray Same shape as ``data``, but boolean. - If ``data`` was a Dataset, this is a DataArray of the same shape as the alignment of all variables. + If ``data`` was a list, this is a DataArray of the same shape as the alignment of all variables. """ # Checks if not isinstance(data, xarray.DataArray): From 214f7651615ee662cfb0b9d3a537bfd0dd769cc1 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 11:32:54 -0400 Subject: [PATCH 03/14] Add tests for heat spells and spell mask --- tests/test_formatting.py | 2 +- tests/test_generic.py | 74 +++++++++++++++++++++++++++++++++ tests/test_temperature.py | 86 +++++++++++++++++++++++++++++++++++++++ xclim/indices/generic.py | 4 +- 4 files changed, 164 insertions(+), 2 deletions(-) diff --git a/tests/test_formatting.py b/tests/test_formatting.py index 98ffb2563..7f58c2a67 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -35,7 +35,7 @@ def test_indicator_docstring(): assert ( doc[41] == " Total number of series of at least {window} consecutive days with daily minimum temperature above " - "{thresh_tasmin} and daily maximum temperature above {thresh_tasmax} (heat_wave_events), " + "{thresh_tasmin} and daily maximum temperature above {thresh_tasmax}, " "with additional attributes: **description**: {freq} number of heat wave events within a given period. " "A heat wave occurs when daily minimum and maximum temperatures exceed {thresh_tasmin} and {thresh_tasmax}, " "respectively, over at least {window} days." 
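As a rough usage sketch of the generic helpers added in PATCH 01 (and exercised by the tests added in this patch), the intended call pattern is along these lines; the data, thresholds and frequency below are illustrative values, not taken from the patch::

    import pandas as pd
    import xarray as xr
    from xclim.indices import generic

    time = pd.date_range("2000-01-01", periods=8, freq="D")
    tasmin = xr.DataArray([20, 23, 23, 23, 20, 20, 23, 23], dims="time", coords={"time": time}, attrs={"units": "degC"})
    tasmax = xr.DataArray([29, 31, 31, 31, 28, 28, 31, 31], dims="time", coords={"time": time}, attrs={"units": "degC"})

    # Boolean mask of the days belonging to a spell where the 3-day means of both
    # variables reach their thresholds (plain floats: spell_mask does not handle units).
    mask = generic.spell_mask([tasmin, tasmax], window=3, win_reducer="mean", op=">=", thresh=[22.0, 30.0])

    # Same condition, reduced to the length of the longest spell in each year;
    # here the thresholds are Quantified strings and units are converted internally.
    longest = generic.bivariate_spell_length_statistics(
        tasmin, "22 degC", tasmax, "30 degC",
        window=3, win_reducer="mean", op=">=", spell_reducer="max", freq="YS",
    )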
diff --git a/tests/test_generic.py b/tests/test_generic.py index 8448a5607..c716a7d86 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -663,3 +663,77 @@ def test_select_time_errors(self): with pytest.raises(TypeError): select_time(da, doy_bounds=(300, 203, 202)) + + +class TestSpellMask: + def test_single_variable(self): + data = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",)) + + out = generic.spell_mask(data, 3, "min", ">=", 2) + np.testing.assert_array_equal( + out, np.array([0, 0, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask(data, 3, "max", ">=", 2) + np.testing.assert_array_equal( + out, np.array([1, 1, 1, 1, 1, 1, 1, 0]).astype(bool) + ) + + out = generic.spell_mask(data, 2, "mean", ">=", 2) + np.testing.assert_array_equal( + out, np.array([0, 0, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask(data, 3, "mean", ">", 2, weights=[0.2, 0.4, 0.4]) + np.testing.assert_array_equal( + out, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + def test_multiple_variables(self): + data1 = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",)) + data2 = xr.DataArray([1, 2, 3, 2, 1, 0, 0, 0], dims=("time",)) + + out = generic.spell_mask([data1, data2], 3, "min", ">=", [2, 2]) + np.testing.assert_array_equal( + out, np.array([0, 0, 0, 0, 0, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask( + [data1, data2], 3, "min", ">=", [2, 2], var_reducer="any" + ) + np.testing.assert_array_equal( + out, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask([data1, data2], 2, "mean", ">=", [2, 2]) + np.testing.assert_array_equal( + out, np.array([0, 0, 1, 1, 0, 0, 0, 0]).astype(bool) + ) + + out = generic.spell_mask( + [data1, data2], 3, "mean", ">", [2, 1.5], weights=[0.2, 0.4, 0.4] + ) + np.testing.assert_array_equal( + out, np.array([0, 1, 1, 1, 1, 0, 0, 0]).astype(bool) + ) + + def test_errors(self): + data = xr.DataArray([0, 1, 2, 3, 2, 1, 0, 0], dims=("time",)) + + # Threshold must be seq + with pytest.raises(ValueError, match="must be a sequence of the same length"): + generic.spell_mask([data, data], 3, "min", "<=", 2) + + # Threshold must be same length + with pytest.raises(ValueError, match="must be a sequence of the same length"): + generic.spell_mask([data, data], 3, "min", "<=", [2]) + + # Weights must have win_reducer = 'mean' + with pytest.raises( + ValueError, match="is only supported if 'win_reducer' is 'mean'" + ): + generic.spell_mask(data, 3, "min", "<=", 2, weights=[1, 2, 3]) + + # Weights must have same length as window + with pytest.raises(ValueError, match="Weights have a different length"): + generic.spell_mask(data, 3, "mean", "<=", 2, weights=[1, 2]) diff --git a/tests/test_temperature.py b/tests/test_temperature.py index 23ed23a0d..ad1a1c401 100644 --- a/tests/test_temperature.py +++ b/tests/test_temperature.py @@ -610,6 +610,92 @@ def test_3d_data_with_nans(self, open_dataset): assert np.isnan(gdd.values[0, -1, -1]) +class TestHeatSpellFrequency: + def test_1d(self, tasmax_series, tasmin_series): + tn1 = np.zeros(366) + tx1 = np.zeros(366) + tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23]) + tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31]) + + tn = tasmin_series(tn1 + K2C, start="1/1/2000") + tx = tasmax_series(tx1 + K2C, start="1/1/2000") + + hsf = atmos.heat_spell_frequency( + tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", freq="YS", op=">" + ) + np.testing.assert_allclose(hsf.values[:1], 2) + + hsf = atmos.heat_spell_frequency( + tn, tx, 
thresh_tasmin="22 C", thresh_tasmax="30 C", window=5, freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 1) + + # no hs + hsf = atmos.heat_spell_frequency( + tn, tx, thresh_tasmin="40 C", thresh_tasmax="40 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 0) + + +class TestHeatSpellMaxLength: + def test_1d(self, tasmax_series, tasmin_series): + tn1 = np.zeros(366) + tx1 = np.zeros(366) + tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23]) + tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31]) + + tn = tasmin_series(tn1 + K2C, start="1/1/2000") + tx = tasmax_series(tx1 + K2C, start="1/1/2000") + + hsf = atmos.heat_spell_max_length( + tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", freq="YS", op=">" + ) + np.testing.assert_allclose(hsf.values[:1], 4) + + hsf = atmos.heat_spell_max_length( + tn, + tx, + thresh_tasmin="22 C", + thresh_tasmax="30 C", + window=5, + freq="YS", + ) + np.testing.assert_allclose(hsf.values[:1], 5) + + # no hs + hsf = atmos.heat_spell_max_length( + tn, tx, thresh_tasmin="40 C", thresh_tasmax="40 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 0) + + +class TestHeatSpellTotalLength: + def test_1d(self, tasmax_series, tasmin_series): + tn1 = np.zeros(366) + tx1 = np.zeros(366) + tn1[:10] = np.array([20, 23, 23, 23, 20, 20, 23, 23, 23, 23]) + tx1[:10] = np.array([29, 31, 31, 31, 28, 28, 31, 31, 31, 31]) + + tn = tasmin_series(tn1 + K2C, start="1/1/2000") + tx = tasmax_series(tx1 + K2C, start="1/1/2000") + + hsf = atmos.heat_spell_total_length( + tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", freq="YS", op=">" + ) + np.testing.assert_allclose(hsf.values[:1], 7) + + hsf = atmos.heat_spell_total_length( + tn, tx, thresh_tasmin="20 C", thresh_tasmax="30 C", window=5, freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 5) + + # no hs + hsf = atmos.heat_spell_total_length( + tn, tx, thresh_tasmin="40 C", thresh_tasmax="40 C", freq="YS" + ) + np.testing.assert_allclose(hsf.values[:1], 0) + + class TestHeatWaveFrequency: def test_1d(self, tasmax_series, tasmin_series): tn1 = np.zeros(366) diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index c246f1776..459103dd6 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -33,6 +33,7 @@ __all__ = [ "aggregate_between_dates", "binary_ops", + "bivariate_spell_length_statistics", "compare", "count_level_crossings", "count_occurrences", @@ -55,6 +56,7 @@ "select_resample_op", "spell_length", "spell_length_statistics", + "spell_mask", "statistics", "temperature_sum", "threshold_count", @@ -407,7 +409,7 @@ def spell_mask( if weights is not None: if win_reducer != "mean": raise ValueError( - f"Argument 'weights' is only supported if 'window_reducer' is 'mean'. Got : {win_reducer}" + f"Argument 'weights' is only supported if 'win_reducer' is 'mean'. 
Got : {win_reducer}" ) elif len(weights) != window: raise ValueError( From 38bd87a7d703be1502ef3a1c0ee663d5f2b56c88 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 11:46:42 -0400 Subject: [PATCH 04/14] Mirror signature of other spells - upd changes --- CHANGELOG.rst | 13 +++++++++++++ tests/test_temperature.py | 16 ++++++++++++---- xclim/indicators/atmos/_temperature.py | 23 ++++++++++++++++------- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5573a3c3e..a850d5983 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,19 @@ Changelog ========= +v0.53.0 (unreleased) +-------------------- +Contributors to this version: Pascal Bourgault (:user:`aulemahal`). + +New indicators +^^^^^^^^^^^^^^ +* New ``heat_spell_frequency``, ``heat_spell_max_length`` and ``heat_spell_total_length`` : spell length statistics on a bivariate condition that uses the average over a window by default. + +New features and enhancements +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. +* Indicator parameters can now be assigned a new name that is used in the indicator's signature and docstring. + v0.52.0 (2024-08-08) -------------------- Contributors to this version: David Huard (:user:`huard`), Trevor James Smith (:user:`Zeitsperre`), Hui-Min Wang (:user:`Hem-W`), Éric Dupuis (:user:`coxipi`), Sarah Gammon (:user:`SarahG-579462`), Pascal Bourgault (:user:`aulemahal`), Juliette Lavoie (:user:`juliettelavoie`), Adrien Lamarche (:user:`LamAdr`). 
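The updated tests below call the indicator-level wrappers directly; a minimal sketch of the intended call pattern follows (the file name is a placeholder, and the threshold values simply restate the defaults set in PATCH 02)::

    import xarray as xr
    from xclim import atmos

    ds = xr.open_dataset("daily_tasmin_tasmax.nc")  # placeholder path

    hs_freq = atmos.heat_spell_frequency(
        tasmin=ds.tasmin,
        tasmax=ds.tasmax,
        thresh_tasmin="20 degC",
        thresh_tasmax="33 degC",
        window=3,
        freq="YS",
    )
    hs_max = atmos.heat_spell_max_length(tasmin=ds.tasmin, tasmax=ds.tasmax, freq="YS")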
diff --git a/tests/test_temperature.py b/tests/test_temperature.py index ad1a1c401..45f3c94b7 100644 --- a/tests/test_temperature.py +++ b/tests/test_temperature.py @@ -621,7 +621,11 @@ def test_1d(self, tasmax_series, tasmin_series): tx = tasmax_series(tx1 + K2C, start="1/1/2000") hsf = atmos.heat_spell_frequency( - tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", freq="YS", op=">" + tn, + tx, + thresh_tasmin="22.1 C", + thresh_tasmax="30.1 C", + freq="YS", ) np.testing.assert_allclose(hsf.values[:1], 2) @@ -648,7 +652,11 @@ def test_1d(self, tasmax_series, tasmin_series): tx = tasmax_series(tx1 + K2C, start="1/1/2000") hsf = atmos.heat_spell_max_length( - tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", freq="YS", op=">" + tn, + tx, + thresh_tasmin="22.1 C", + thresh_tasmax="30.1 C", + freq="YS", ) np.testing.assert_allclose(hsf.values[:1], 4) @@ -680,12 +688,12 @@ def test_1d(self, tasmax_series, tasmin_series): tx = tasmax_series(tx1 + K2C, start="1/1/2000") hsf = atmos.heat_spell_total_length( - tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", freq="YS", op=">" + tn, tx, thresh_tasmin="22.1 C", thresh_tasmax="30.1 C", freq="YS" ) np.testing.assert_allclose(hsf.values[:1], 7) hsf = atmos.heat_spell_total_length( - tn, tx, thresh_tasmin="20 C", thresh_tasmax="30 C", window=5, freq="YS" + tn, tx, thresh_tasmin="22 C", thresh_tasmax="30 C", window=5, freq="YS" ) np.testing.assert_allclose(hsf.values[:1], 5) diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py index ae5e5122b..b4dc5a041 100644 --- a/xclim/indicators/atmos/_temperature.py +++ b/xclim/indicators/atmos/_temperature.py @@ -279,10 +279,13 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): compute=indices.generic.bivariate_spell_length_statistics, input={"data1": "tasmin", "data2": "tasmax"}, parameters=dict( - win_reducer="mean", spell_reducer="count", + op=">=", window={"default": 3}, - op={"default": ">="}, + win_reducer={ + "default": "mean", + "name": "op", + }, # To mirror other spell indicators freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", @@ -311,10 +314,13 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): compute=indices.generic.bivariate_spell_length_statistics, input={"data1": "tasmin", "data2": "tasmax"}, parameters=dict( - win_reducer="mean", spell_reducer="max", + op=">=", window={"default": 3}, - op={"default": ">="}, + win_reducer={ + "default": "mean", + "name": "op", + }, # To mirror other spell indicators freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", @@ -343,10 +349,13 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): compute=indices.generic.bivariate_spell_length_statistics, input={"data1": "tasmin", "data2": "tasmax"}, parameters=dict( - win_reducer="mean", - spell_reducer="max", + spell_reducer="sum", + op=">=", window={"default": 3}, - op={"default": ">="}, + win_reducer={ + "default": "mean", + "name": "op", + }, # To mirror other spell indicators freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", From f5bbcdd759b9efa787f0a813b94a52a3ddc79e9b Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 11:54:35 -0400 Subject: [PATCH 05/14] add numbers to changes --- CHANGELOG.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a850d5983..464e6a8ae 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,12 +8,12 @@ Contributors to this version: Pascal Bourgault 
(:user:`aulemahal`). New indicators ^^^^^^^^^^^^^^ -* New ``heat_spell_frequency``, ``heat_spell_max_length`` and ``heat_spell_total_length`` : spell length statistics on a bivariate condition that uses the average over a window by default. +* New ``heat_spell_frequency``, ``heat_spell_max_length`` and ``heat_spell_total_length`` : spell length statistics on a bivariate condition that uses the average over a window by default. (:pull:`1885`). New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. -* Indicator parameters can now be assigned a new name that is used in the indicator's signature and docstring. +* New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. (:pull:`1885`). +* Indicator parameters can now be assigned a new name that is used in the indicator's signature and docstring. (:pull:`1885`). v0.52.0 (2024-08-08) -------------------- From 525156d8faaad612d705d9ab934ee0bcaee395a1 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 15:19:12 -0400 Subject: [PATCH 06/14] Fix for sphinx - remove superfluous words --- xclim/core/indicator.py | 2 +- xclim/indicators/atmos/_temperature.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 75fc6a3ee..9115561d8 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -196,9 +196,9 @@ class Parameter: _empty = _empty kind: InputKind - name: str = _empty default: Any = _empty_default description: str = "" + name: str = _empty units: str = _empty choices: set = _empty value: Any = _empty diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py index b4dc5a041..452bc4724 100644 --- a/xclim/indicators/atmos/_temperature.py +++ b/xclim/indicators/atmos/_temperature.py @@ -269,7 +269,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): identifier="heat_spell_frequency", units="", long_name="Number of heat spells", - description="{freq} number of heat spells events within a given period. A heat spell occurs when the {window}-day " + description="{freq} number of heat spells events. A heat spell occurs when the {window}-day " "averages of daily minimum and maximum temperatures each exceed {threshold1} and {threshold2}. " "All days of the {window}-day period are considered part of the spell.", abstract="Number of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " @@ -306,7 +306,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): units="days", standard_name="spell_length_of_days_with_air_temperature_above_threshold", long_name="Longest heat spell", - description="{freq} maximum length of heat spells within a given period. A heat spell occurs when the {window}-day " + description="{freq} maximum length of heat spells. A heat spell occurs when the {window}-day " "averages of daily minimum and maximum temperatures each exceed {threshold1} and {threshold2}. 
" "All days of the {window}-day period are considered part of the spell.", abstract="The longest heat spell of a period. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " @@ -341,7 +341,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): units="days", standard_name="spell_length_of_days_with_air_temperature_above_threshold", long_name="Total length of heat spells.", - description="{freq} total length of heat spell events occurring within a given period. " + description="{freq} total length of heat spell events. " "A heat spell occurs when the {window}-day averages of daily minimum and maximum temperatures " "each exceed {threshold1} and {threshold2}. All days of the {window}-day period are considered part of the spell.", abstract="Total length of heat spells. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " From 4acea8b8d30593de99660e6ea6097c607a557e3b Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 17:55:19 -0400 Subject: [PATCH 07/14] Renaming only for sig didnt work - full parameter renaming --- CHANGELOG.rst | 6 +- xclim/core/formatting.py | 2 - xclim/core/indicator.py | 119 +++++++++++-------------- xclim/indicators/atmos/_temperature.py | 21 ++--- 4 files changed, 64 insertions(+), 84 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 464e6a8ae..1e181b7ac 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,7 +13,11 @@ New indicators New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. (:pull:`1885`). -* Indicator parameters can now be assigned a new name that is used in the indicator's signature and docstring. (:pull:`1885`). +* Indicator parameters can now be assigned a new name, different from the argument name in the compute function. (:pull:`1885`). + +Internal changes +^^^^^^^^^^^^^^^^ +* ``xclim.core.indicator.Parameter`` has a new attribute ``compute_name`` while ``xclim.core.indicator.Indicator`` lost its ``_variable_mapping``. The translation from parameter (and variable) names in the indicator to the names on the compute function is handled by ``Indicator._get_compute_args``. (:pull:`1885`). v0.52.0 (2024-08-08) -------------------- diff --git a/xclim/core/formatting.py b/xclim/core/formatting.py index e9883ec19..aaa18025b 100644 --- a/xclim/core/formatting.py +++ b/xclim/core/formatting.py @@ -582,8 +582,6 @@ def _gen_parameters_section( """ section = "Parameters\n----------\n" for name, param in parameters.items(): - if param.name is not param._empty: - name = param.name desc_str = param.description if param.kind == InputKind.FREQ_STR: desc_str += ( diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 9115561d8..6c6381fb4 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -78,7 +78,7 @@ units: # Only valid if "compute" points to a generic function default : description: - name : # A name to use when generating the docstring and signature. + name : # Change the name of the parameter (similar to what `input` does for variables) kind: # Override the parameter kind. # This is mostly useful for transforming an optional variable into a required one by passing ``kind: 0``. ... 
@@ -196,9 +196,11 @@ class Parameter: _empty = _empty kind: InputKind + compute_name: str = ( + _empty # Name of the compute function's argument corresponding to this parameter. + ) default: Any = _empty_default description: str = "" - name: str = _empty units: str = _empty choices: set = _empty value: Any = _empty @@ -213,18 +215,13 @@ def update(self, other: dict) -> None: @classmethod def is_parameter_dict(cls, other: dict) -> bool: - """Return whether indicator has a parameter dictionary.""" + """Return whether other can update a parameter dictionary.""" + # Passing compute_name is forbidden. + # name is valid, but is handled by the indicator return set(other.keys()).issubset( - cls.__dataclass_fields__.keys() # pylint: disable=no-member + {"kind", "default", "description", "units", "choices", "value", "name"} ) - # def __getitem__(self, key) -> str: - # """Return an item in retro-compatible fashion.""" - # try: - # return str(getattr(self, key)) - # except AttributeError as err: - # raise KeyError(key) from err - def __contains__(self, key) -> bool: """Imitate previous behaviour where "units" and "choices" were missing, instead of being "_empty".""" return getattr(self, key, _empty) is not _empty @@ -353,10 +350,8 @@ class Indicator(IndicatorRegistrar): # metadata fields that are formatted as free text (first letter capitalized) _text_fields = ["long_name", "description", "comment"] - # Class attributes that are function (so we know which to convert to static methods) + # Class attributes that are functions (so we know which to convert to static methods) _funcs = ["compute"] - # Mapping from name in the compute function to official (CMIP6) variable name - _variable_mapping = {} # Will become the class's name identifier = None @@ -438,33 +433,19 @@ def __new__(cls, **kwds): # noqa: C901 else: # inherit parameters from base class parameters = deepcopy(cls._all_parameters) - # Update parameters with passed parameters + # Update parameters with passed parameters, might change some parameters name (but not variables) cls._update_parameters(parameters, kwds.pop("parameters", {})) # Input variable mapping (to change variable names in signature and expected units/cf attrs). - cls._parse_var_mapping(kwds.pop("input", {}), parameters, kwds) + # new_units is a mapping from compute function name to units inferred from the var mapping + new_units = cls._parse_var_mapping(kwds.pop("input", {}), parameters) # Raise on incorrect params, sort params, modify var defaults in-place if needed parameters = cls._ensure_correct_parameters(parameters) - # If needed, wrap compute with declare units - if "compute" in kwds: - if not hasattr(kwds["compute"], "in_units") and "_variable_mapping" in kwds: - # We actually need the inverse mapping (to get cmip6 name -> arg name) - inv_var_map = dict(map(reversed, kwds["_variable_mapping"].items())) - # parameters has already been update above. - kwds["compute"] = declare_units( - **{ - inv_var_map[k]: m.units - for k, m in parameters.items() - if "units" in m and k in inv_var_map - } - )(kwds["compute"]) - - if hasattr(kwds["compute"], "in_units"): - varmap = kwds.get("_variable_mapping", {}) - for name, unit in kwds["compute"].in_units.items(): - parameters[varmap.get(name, name)].units = unit + if "compute" in kwds and not hasattr(kwds["compute"], "in_units") and new_units: + # If needed, wrap compute with declare units + kwds["compute"] = declare_units(**new_units)(kwds["compute"]) # All updates done. 
kwds["_all_parameters"] = parameters @@ -535,8 +516,11 @@ def _parse_indice(compute, passed_parameters): # noqa: F841 ) for name, param in compute_sig.parameters.items(): meta = params_dict.setdefault(name, {}) + meta["compute_name"] = name meta["default"] = param.default meta["kind"] = infer_kind_from_parameter(param) + if hasattr(compute, "in_units") and name in compute.in_units: + meta["units"] = compute.in_units[name] parameters = {name: Parameter(**param) for name, param in params_dict.items()} return parameters, docmeta @@ -561,7 +545,14 @@ def _update_parameters(cls, parameters, passed): try: for key, val in passed.items(): if isinstance(val, dict) and Parameter.is_parameter_dict(val): - # modified meta + if "name" in val: + new_key = val.pop("name") + if new_key in parameters: + raise ValueError( + f"Cannot rename a parameter or variable with the same name as another parameter. '{new_key}' is already a parameter." + ) + parameters[new_key] = parameters.pop(key) + key = new_key parameters[key].update(val) elif key in parameters: parameters[key].value = val @@ -574,9 +565,10 @@ def _update_parameters(cls, parameters, passed): ) from err @classmethod - def _parse_var_mapping(cls, variable_mapping, parameters, kwds): + def _parse_var_mapping(cls, variable_mapping, parameters): """Parse the variable mapping passed in `input` and update `parameters` in-place.""" # Update parameters + new_units = {} for old_name, new_name in variable_mapping.items(): meta = parameters[new_name] = parameters.pop(old_name) try: @@ -597,13 +589,9 @@ def _parse_var_mapping(cls, variable_mapping, parameters, kwds): f"{meta.units}, new = {varmeta['canonical_units']}" ) from err meta.units = varmeta.get("dimensions", varmeta["canonical_units"]) + new_units[meta.compute_name] = meta.units meta.description = varmeta["description"] - - if variable_mapping: - # Update mapping attribute - new_variable_mapping = deepcopy(cls._variable_mapping) - new_variable_mapping.update(variable_mapping) - kwds["_variable_mapping"] = new_variable_mapping + return new_units @classmethod def _ensure_correct_parameters(cls, parameters): @@ -790,13 +778,12 @@ def _gen_signature(self): ) ) else: - show_name = name if meta.name is meta._empty else meta.name parameters.append( _Parameter( - show_name, + name, kind=_Parameter.KEYWORD_ONLY, default=meta.default, - annotation=compute_sig.parameters[name].annotation, + annotation=compute_sig.parameters[meta.compute_name].annotation, ) ) @@ -828,22 +815,10 @@ def __call__(self, *args, **kwds): das, params = self._preprocess_and_checks(das, params) - # Get correct variable names for the compute function. - inv_var_map = dict(map(reversed, self._variable_mapping.items())) - compute_das = {inv_var_map.get(nm, nm): das[nm] for nm in das} - - # Compute the indicator values, ignoring NaNs and missing values. - # Filter the passed parameters to only keep the ones needed by compute. 
- kwargs = {} - var_kwargs = {} - for nm, pa in signature(self.compute).parameters.items(): - if pa.kind == _Parameter.VAR_KEYWORD: - var_kwargs = params[nm] - elif nm not in compute_das and nm in params: - kwargs[nm] = params[nm] - + # get mappings where keys are the actual compute function's argument names + compute_das, compute_params, var_kwargs = self._get_compute_args(das, params) with xarray.set_options(keep_attrs=False): - outs = self.compute(**compute_das, **kwargs, **var_kwargs) + outs = self.compute(**compute_das, **compute_params, **var_kwargs) if isinstance(outs, DataArray): outs = [outs] @@ -909,11 +884,6 @@ def _parse_variables_from_call(self, args, kwds) -> tuple[OrderedDict, dict]: ba = self.__signature__.bind(*args, **kwds) ba.apply_defaults() - # Assign parameters with different signature-names correctly - for name, param in self._all_parameters.items(): - if param.name is not param._empty: - ba.arguments[name] = ba.arguments.pop(param.name) - # Assign inputs passed as strings from ds. self._assign_named_args(ba) @@ -934,7 +904,7 @@ def _parse_variables_from_call(self, args, kwds) -> tuple[OrderedDict, dict]: else: params[name] = param.value - ds = ba.arguments.get("ds") + ds = params.pop("ds", None) dsattrs = ds.attrs if ds is not None else {} return das, params, dsattrs @@ -970,6 +940,23 @@ def _preprocess_and_checks(self, das, params): self._bind_call(self.cfcheck, **das) return das, params + def _get_compute_args(self, das, params): + """Rename variables and parameters to match the compute function's names and split VAR_KEYWORD arguments.""" + # Get correct variable names for the compute function. + inv_var_map = { + key: p.compute_name + for key, p in self._all_parameters.items() + if p.compute_name is not _empty + } + compute_das = {inv_var_map[nm]: das[nm] for nm in das} + compute_params = {inv_var_map[nm]: params[nm] for nm in params} + + var_kwargs = {} + for key, p in self._all_parameters.items(): + if p.kind == InputKind.KWARGS and p.compute_name in compute_params: + var_kwargs.update(compute_params.pop(p.compute_name)) + return compute_das, compute_params, var_kwargs + def _postprocess(self, outs, das, params): """Actions to done after computing.""" return outs diff --git a/xclim/indicators/atmos/_temperature.py b/xclim/indicators/atmos/_temperature.py index 452bc4724..2394b6e9f 100644 --- a/xclim/indicators/atmos/_temperature.py +++ b/xclim/indicators/atmos/_temperature.py @@ -270,7 +270,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): units="", long_name="Number of heat spells", description="{freq} number of heat spells events. A heat spell occurs when the {window}-day " - "averages of daily minimum and maximum temperatures each exceed {threshold1} and {threshold2}. " + "averages of daily minimum and maximum temperatures each exceed {thresh_tasmin} and {thresh_tasmax}. " "All days of the {window}-day period are considered part of the spell.", abstract="Number of heat spells. 
A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " "thresholds for a number of days.", @@ -282,10 +282,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): spell_reducer="count", op=">=", window={"default": 3}, - win_reducer={ - "default": "mean", - "name": "op", - }, # To mirror other spell indicators + win_reducer={"default": "mean"}, freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", @@ -307,7 +304,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): standard_name="spell_length_of_days_with_air_temperature_above_threshold", long_name="Longest heat spell", description="{freq} maximum length of heat spells. A heat spell occurs when the {window}-day " - "averages of daily minimum and maximum temperatures each exceed {threshold1} and {threshold2}. " + "averages of daily minimum and maximum temperatures each exceed {thresh_tasmin} and {thresh_tasmax}. " "All days of the {window}-day period are considered part of the spell.", abstract="The longest heat spell of a period. A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " "thresholds for a number of days.", @@ -317,10 +314,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): spell_reducer="max", op=">=", window={"default": 3}, - win_reducer={ - "default": "mean", - "name": "op", - }, # To mirror other spell indicators + win_reducer={"default": "mean"}, freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", @@ -343,7 +337,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): long_name="Total length of heat spells.", description="{freq} total length of heat spell events. " "A heat spell occurs when the {window}-day averages of daily minimum and maximum temperatures " - "each exceed {threshold1} and {threshold2}. All days of the {window}-day period are considered part of the spell.", + "each exceed {thresh_tasmin} and {thresh_tasmax}. All days of the {window}-day period are considered part of the spell.", abstract="Total length of heat spells. 
A heat spell occurs when rolling averages of daily minimum and maximum temperatures exceed given " "thresholds for a number of days.", compute=indices.generic.bivariate_spell_length_statistics, @@ -352,10 +346,7 @@ class TempWithIndexing(ResamplingIndicatorWithIndexing): spell_reducer="sum", op=">=", window={"default": 3}, - win_reducer={ - "default": "mean", - "name": "op", - }, # To mirror other spell indicators + win_reducer={"default": "mean"}, freq={"default": "YS"}, threshold1={ "description": "Threshold for tasmin", From 041b94282673590093c3c8eed3e909b31dd84e2c Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 18:20:15 -0400 Subject: [PATCH 08/14] Fix for indexer in IndexingIndicator --- tests/test_indicators.py | 3 +-- xclim/core/indicator.py | 49 ++++++++++++++++++---------------------- xclim/data/fr.json | 6 ++--- 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/tests/test_indicators.py b/tests/test_indicators.py index 4ca95ea2f..41fd31bd3 100644 --- a/tests/test_indicators.py +++ b/tests/test_indicators.py @@ -708,8 +708,7 @@ def test_indicator_from_dict(): assert ind.parameters["threshold"].description == "A threshold temp" # Injection of parameters assert ind.injected_parameters["op"] == "<" - # Default value for input variable injected and meta injected - assert ind._variable_mapping["data"] == "tas" + assert ind.parameters["tas"].compute_name == "data" assert signature(ind).parameters["tas"].default == "tas" assert ind.parameters["tas"].units == "[temperature]" diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 6c6381fb4..29e712d46 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -904,7 +904,7 @@ def _parse_variables_from_call(self, args, kwds) -> tuple[OrderedDict, dict]: else: params[name] = param.value - ds = params.pop("ds", None) + ds = params.get("ds") dsattrs = ds.attrs if ds is not None else {} return das, params, dsattrs @@ -943,18 +943,18 @@ def _preprocess_and_checks(self, das, params): def _get_compute_args(self, das, params): """Rename variables and parameters to match the compute function's names and split VAR_KEYWORD arguments.""" # Get correct variable names for the compute function. 
- inv_var_map = { - key: p.compute_name - for key, p in self._all_parameters.items() - if p.compute_name is not _empty - } - compute_das = {inv_var_map[nm]: das[nm] for nm in das} - compute_params = {inv_var_map[nm]: params[nm] for nm in params} - - var_kwargs = {} + # Exclude param without a mapping inside the compute functions (those injected by the indicator class) + compute_das, compute_params, var_kwargs = {}, {}, {} for key, p in self._all_parameters.items(): - if p.kind == InputKind.KWARGS and p.compute_name in compute_params: - var_kwargs.update(compute_params.pop(p.compute_name)) + if p.compute_name is not _empty: + if key in das: + compute_das[p.compute_name] = das[key] + if key in params: + if p.kind == InputKind.KWARGS: + var_kwargs.update(params[key]) + else: + compute_params[p.compute_name] = params[key] + return compute_das, compute_params, var_kwargs def _postprocess(self, outs, das, params): @@ -1531,21 +1531,16 @@ def _preprocess_and_checks(self, das, params): class IndexingIndicator(Indicator): """Indicator that also injects "indexer" kwargs to subset the inputs before computation.""" - @classmethod - def _injected_parameters(cls): - return super()._injected_parameters() + [ - ( - "indexer", - Parameter( - kind=InputKind.KWARGS, - description=( - "Indexing parameters to compute the indicator on a temporal " - "subset of the data. It accepts the same arguments as " - ":py:func:`xclim.indices.generic.select_time`." - ), - ), - ) - ] + def __init__(self, *args, **kwargs): + self._all_parameters["indexer"] = Parameter( + kind=InputKind.KWARGS, + description=( + "Indexing parameters to compute the indicator on a temporal " + "subset of the data. It accepts the same arguments as " + ":py:func:`xclim.indices.generic.select_time`." + ), + ) + super().__init__(*args, **kwargs) def _preprocess_and_checks(self, das: dict[str, DataArray], params: dict[str, Any]): """Perform parent's checks and also check if freq is allowed.""" diff --git a/xclim/data/fr.json b/xclim/data/fr.json index 07c2e0933..d11f72253 100644 --- a/xclim/data/fr.json +++ b/xclim/data/fr.json @@ -337,19 +337,19 @@ }, "HEAT_SPELL_FREQUENCY": { "long_name": "Nombre de vagues de chaleur.", - "description": "Nombre {freq:m} de vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {windows} jours de la moyenne sont comptés dans la vague.", + "description": "Nombre {freq:m} de vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {window} jours de la moyenne sont comptés dans la vague.", "title": "Fréquence des vagues de chaleur", "abstract": "Nombre de vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." }, "HEAT_SPELL_MAX_LENGTH": { "long_name": "Vague de chaleur la plus longue.", - "description": "Durée maximale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {windows} jours de la moyenne sont comptés dans la vague.", + "description": "Durée maximale {freq:f} des vagues de chaleur. 
Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {window} jours de la moyenne sont comptés dans la vague.", "title": "Durée maximale des vagues de chaleur", "abstract": "Durée maximale des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." }, "HEAT_SPELL_TOTAL_LENGTH": { "long_name": "Durée totale des vagues de chaleur.", - "description": "Durée totale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {windows} jours de la moyenne sont comptés dans la vague.", + "description": "Durée totale {freq:f} des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes sur {window} jours des températures minimale et maximale quotidiennes excèdent {thresh_tasmin} et {thresh_tasmax}, respectivement. Les {window} jours de la moyenne sont comptés dans la vague.", "title": "Durée totale des vagues de chaleur", "abstract": "Durée totale des vagues de chaleur. Une vague de chaleur se produit lorsque les moyennes mobiles des températures minimale et maximale quotidiennes excèdent des seuils donnés." }, From 5e73a6a74169925fae056b1f0afe7edc1541a236 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 18:48:01 -0400 Subject: [PATCH 09/14] Fix for Percentile DA --- xclim/core/indicator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 29e712d46..ddcc13c39 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -949,7 +949,8 @@ def _get_compute_args(self, das, params): if p.compute_name is not _empty: if key in das: compute_das[p.compute_name] = das[key] - if key in params: + # elif because some args are in both (percentile DataArrays) + elif key in params: if p.kind == InputKind.KWARGS: var_kwargs.update(params[key]) else: From d472e02c0760390d2fd1cb6f7da1ab93156ae150 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 15 Aug 2024 18:50:05 -0400 Subject: [PATCH 10/14] simplify --- xclim/core/indicator.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index ddcc13c39..f6a734ab0 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -816,9 +816,9 @@ def __call__(self, *args, **kwds): das, params = self._preprocess_and_checks(das, params) # get mappings where keys are the actual compute function's argument names - compute_das, compute_params, var_kwargs = self._get_compute_args(das, params) + args = self._get_compute_args(das, params) with xarray.set_options(keep_attrs=False): - outs = self.compute(**compute_das, **compute_params, **var_kwargs) + outs = self.compute(**args) if isinstance(outs, DataArray): outs = [outs] @@ -944,19 +944,19 @@ def _get_compute_args(self, das, params): """Rename variables and parameters to match the compute function's names and split VAR_KEYWORD arguments.""" # Get correct variable names for the compute function. 
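# Illustrative sketch (not part of the original patch): the change below folds the three mappings
# this method used to return (variables, scalar parameters and the flattened VAR_KEYWORD entries)
# into a single dict, so the call site becomes ``self.compute(**args)``. A rough standalone
# equivalent, with hypothetical values:
compute_das = {"data1": "<tasmin DataArray>"}
compute_params = {"threshold1": "22 degC", "freq": "YS"}
var_kwargs = {"doy_bounds": (152, 243)}  # e.g. an "indexer" forwarded to select_time
args = {**compute_das, **compute_params, **var_kwargs}
# outs = self.compute(**args)  # one unpacking instead of compute(**das, **params, **kwargs)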
         # Exclude param without a mapping inside the compute functions (those injected by the indicator class)
-        compute_das, compute_params, var_kwargs = {}, {}, {}
+        args = {}
         for key, p in self._all_parameters.items():
             if p.compute_name is not _empty:
                 if key in das:
-                    compute_das[p.compute_name] = das[key]
+                    args[p.compute_name] = das[key]
                 # elif because some args are in both (percentile DataArrays)
                 elif key in params:
                     if p.kind == InputKind.KWARGS:
-                        var_kwargs.update(params[key])
+                        args.update(params[key])
                     else:
-                        compute_params[p.compute_name] = params[key]
+                        args[p.compute_name] = params[key]
 
-        return compute_das, compute_params, var_kwargs
+        return args
 
     def _postprocess(self, outs, das, params):
         """Actions to done after computing."""

From c0016ee2d75d93cc0a50fa6cc3699b133e204cf6 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Thu, 15 Aug 2024 19:11:46 -0400
Subject: [PATCH 11/14] Fix doc of Parameter again

---
 xclim/core/indicator.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py
index f6a734ab0..880dc01cb 100644
--- a/xclim/core/indicator.py
+++ b/xclim/core/indicator.py
@@ -196,10 +196,9 @@ class Parameter:
     _empty = _empty
 
     kind: InputKind
-    compute_name: str = (
-        _empty  # Name of the compute function's argument corresponding to this parameter.
-    )
     default: Any = _empty_default
+    # Name of the compute function's argument corresponding to this parameter.
+    compute_name: str = _empty
     description: str = ""
     units: str = _empty
     choices: set = _empty

From 3d276a07aad907ef6c7c8c1d57b643540ae7f593 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Sun, 18 Aug 2024 09:55:58 -0400
Subject: [PATCH 12/14] Fix window 1 fast-path

---
 xclim/indices/generic.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 459103dd6..fdfc55937 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -418,9 +418,9 @@ def spell_mask(
         weights = xarray.DataArray(weights, dims=("window",))
     if window == 1:
         # Fast path
-        mask = compare(data, op, thresh)
+        is_in_spell = compare(data, op, thresh)
         if not np.isscalar(thresh):
-            is_in_spell = getattr(mask, var_reducer)("variable")
+            is_in_spell = getattr(is_in_spell, var_reducer)("variable")
     elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or (
         win_reducer == "max" and op in ["<", "<=", "le", "lt"]
     ):

From fb9e59b2c537fd60d7b8b0464fae205536e93216 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Thu, 5 Sep 2024 15:41:13 -0400
Subject: [PATCH 13/14] multi reducing

---
 tests/test_generic.py    |  31 ++++++++++++
 xclim/indices/generic.py | 107 +++++++++++++++++++++++++--------------
 2 files changed, 101 insertions(+), 37 deletions(-)

diff --git a/tests/test_generic.py b/tests/test_generic.py
index c716a7d86..0054dd935 100644
--- a/tests/test_generic.py
+++ b/tests/test_generic.py
@@ -737,3 +737,34 @@ def test_errors(self):
         # Weights must have same length as window
         with pytest.raises(ValueError, match="Weights have a different length"):
             generic.spell_mask(data, 3, "mean", "<=", 2, weights=[1, 2])
+
+
+def test_spell_length_statistics_multi(tasmin_series, tasmax_series):
+    tn = tasmin_series(
+        np.zeros(
+            365,
+        )
+        + 270,
+        start="2001-01-01",
+    )
+    tx = tasmax_series(
+        np.zeros(
+            365,
+        )
+        + 270,
+        start="2001-01-01",
+    )
+
+    outc, outs, outm = generic.bivariate_spell_length_statistics(
+        tn,
+        "0 °C",
+        tx,
+        "1°C",
+        window=5,
+        win_reducer="min",
+        op="<",
+        spell_reducer=["count", "sum",
"max"], + freq="YS", + ) + xr.testing.assert_equal(outs, outm) + np.testing.assert_allclose(outc, 1) diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index 2ad8f156b..74b1ef3d0 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -404,7 +404,10 @@ def spell_mask( "When ``data`` is given as a list, ``threshold`` must be a sequence of the same length." ) data = xarray.concat(data, "variable") - thresh = xarray.DataArray(thresh, dims=("variable",)) + if isinstance(thresh[0], xarray.DataArray): + thresh = xr.concat(thresh, "variable") + else: + thresh = xarray.DataArray(thresh, dims=("variable",)) if weights is not None: if win_reducer != "mean": raise ValueError( @@ -455,6 +458,50 @@ def spell_mask( return is_in_spell +def _spell_length_statistics( + data: xarray.DataArray | Sequence[xarray.DataArray], + thresh: float | xarray.DataArray | Sequence[xarray.DataArray] | Sequence[float], + window: int, + win_reducer: str, + op: str, + spell_reducer: str | Sequence[str], + freq: str, + resample_before_rl: bool = True, + **indexer, +) -> xarray.DataArray | Sequence[xarray.DataArray]: + if isinstance(spell_reducer, str): + spell_reducer = [spell_reducer] + is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) + is_in_spell = select_time(is_in_spell, **indexer) + + outs = [] + for sr in spell_reducer: + out = rl.resample_and_rl( + is_in_spell, + resample_before_rl, + rl.rle_statistics, + reducer=sr, + # The code above already ensured only spell of the minimum length are selected + window=1, + freq=freq, + ) + + if sr == "count": + outs.append(out.assign_attrs(units="")) + else: + # All other cases are statistics of the number of timesteps + outs.append( + to_agg_units( + out, + data if isinstance(data, xarray.DataArray) else data[0], + "count", + ) + ) + if len(outs) == 1: + return outs[0] + return tuple(outs) + + @declare_relative_units(threshold="") def spell_length_statistics( data: xr.DataArray, @@ -485,8 +532,8 @@ def spell_length_statistics( Note that this does not matter when `window` is 1. op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} Logical operator. Ex: spell_value > thresh. - spell_reducer : {'max', 'sum', 'count'} - Statistic on the spell lengths. + spell_reducer : {'max', 'sum', 'count'} or sequence thereof + Statistic on the spell lengths. If a list, multiple statistics are computed. freq : str Resampling frequency. resample_before_rl : bool @@ -532,24 +579,18 @@ def spell_length_statistics( bivariate_spell_length_statistics : The bivariate version of this function. """ thresh = convert_units_to(threshold, data, context="infer") - is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) - is_in_spell = select_time(is_in_spell, **indexer) - - out = rl.resample_and_rl( - is_in_spell, + return _spell_length_statistics( + data, + thresh, + window, + win_reducer, + op, + spell_reducer, + freq, resample_before_rl, - rl.rle_statistics, - reducer=spell_reducer, - # The code above already ensured only spell of the minimum length are selected - window=1, - freq=freq, + **indexer, ) - if spell_reducer == "count": - return out.assign_attrs(units="") - # All other cases are statistics of the number of timesteps - return to_agg_units(out, data, "count") - @declare_relative_units(threshold1="", threshold2="") def bivariate_spell_length_statistics( @@ -587,8 +628,8 @@ def bivariate_spell_length_statistics( Note that this does not matter when `window` is 1. 
op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} Logical operator. Ex: spell_value > thresh. - spell_reducer : {'max', 'sum', 'count'} - Statistic on the spell lengths. + spell_reducer : {'max', 'sum', 'count'} or sequence thereof + Statistic on the spell lengths. If a list, multiple statistics are computed. freq : str Resampling frequency. resample_before_rl : bool @@ -606,26 +647,18 @@ def bivariate_spell_length_statistics( """ thresh1 = convert_units_to(threshold1, data1, context="infer") thresh2 = convert_units_to(threshold2, data2, context="infer") - is_in_spell = spell_mask( - [data1, data2], window, win_reducer, op, [thresh1, thresh2], var_reducer="all" - ).astype(np.float32) - is_in_spell = select_time(is_in_spell, **indexer) - - out = rl.resample_and_rl( - is_in_spell, + return _spell_length_statistics( + [data1, data2], + [thresh1, thresh2], + window, + win_reducer, + op, + spell_reducer, + freq, resample_before_rl, - rl.rle_statistics, - reducer=spell_reducer, - # The code above already ensured only spell of the minimum length are selected - window=1, - freq=freq, + **indexer, ) - if spell_reducer == "count": - return out.assign_attrs(units="") - # All other cases are statistics of the number of timesteps - return to_agg_units(out, data1, "count") - @declare_relative_units(thresh="") def season( From 81d46bf3ea7cb721646cca66edc0201c8c1580b4 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 6 Sep 2024 13:20:14 -0400 Subject: [PATCH 14/14] xr instead of xarray --- xclim/indices/generic.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index 74b1ef3d0..43bc85720 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -356,14 +356,14 @@ def get_daily_events( def spell_mask( - data: xarray.DataArray | Sequence[xarray.DataArray], + data: xr.DataArray | Sequence[xr.DataArray], window: int, win_reducer: str, op: str, thresh: float | Sequence[float], weights: Sequence[float] = None, var_reducer: str = "all", -) -> xarray.DataArray: +) -> xr.DataArray: """Compute the boolean mask of data points that are part of a spell as defined by a rolling statistic. A day is part of a spell (True in the mask) if it is contained in any period that fulfills the condition. @@ -392,22 +392,22 @@ def spell_mask( Returns ------- - xarray.DataArray + xr.DataArray Same shape as ``data``, but boolean. If ``data`` was a list, this is a DataArray of the same shape as the alignment of all variables. """ # Checks - if not isinstance(data, xarray.DataArray): + if not isinstance(data, xr.DataArray): # thus a sequence if np.isscalar(thresh) or len(data) != len(thresh): raise ValueError( "When ``data`` is given as a list, ``threshold`` must be a sequence of the same length." ) - data = xarray.concat(data, "variable") - if isinstance(thresh[0], xarray.DataArray): + data = xr.concat(data, "variable") + if isinstance(thresh[0], xr.DataArray): thresh = xr.concat(thresh, "variable") else: - thresh = xarray.DataArray(thresh, dims=("variable",)) + thresh = xr.DataArray(thresh, dims=("variable",)) if weights is not None: if win_reducer != "mean": raise ValueError( @@ -417,7 +417,7 @@ def spell_mask( raise ValueError( f"Weights have a different length ({len(weights)}) than the window ({window})." 
) - weights = xarray.DataArray(weights, dims=("window",)) + weights = xr.DataArray(weights, dims=("window",)) if window == 1: # Fast path is_in_spell = compare(data, op, thresh) @@ -459,8 +459,8 @@ def spell_mask( def _spell_length_statistics( - data: xarray.DataArray | Sequence[xarray.DataArray], - thresh: float | xarray.DataArray | Sequence[xarray.DataArray] | Sequence[float], + data: xr.DataArray | Sequence[xr.DataArray], + thresh: float | xr.DataArray | Sequence[xr.DataArray] | Sequence[float], window: int, win_reducer: str, op: str, @@ -468,7 +468,7 @@ def _spell_length_statistics( freq: str, resample_before_rl: bool = True, **indexer, -) -> xarray.DataArray | Sequence[xarray.DataArray]: +) -> xr.DataArray | Sequence[xr.DataArray]: if isinstance(spell_reducer, str): spell_reducer = [spell_reducer] is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) @@ -493,7 +493,7 @@ def _spell_length_statistics( outs.append( to_agg_units( out, - data if isinstance(data, xarray.DataArray) else data[0], + data if isinstance(data, xr.DataArray) else data[0], "count", ) ) @@ -594,9 +594,9 @@ def spell_length_statistics( @declare_relative_units(threshold1="", threshold2="") def bivariate_spell_length_statistics( - data1: xarray.DataArray, + data1: xr.DataArray, threshold1: Quantified, - data2: xarray.DataArray, + data2: xr.DataArray, threshold2: Quantified, window: int, win_reducer: str,
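
Usage sketch (editorial addition, not part of the patch series): the snippet below mirrors the new test to show how the multi-reducer call introduced in PATCH 13/14 is meant to be used. It assumes an xclim build that includes these commits; the synthetic data and threshold values are arbitrary.

import numpy as np
import pandas as pd
import xarray as xr

from xclim.indices.generic import bivariate_spell_length_statistics

# One year of constant daily minimum and maximum temperatures, in kelvin.
time = pd.date_range("2001-01-01", periods=365, freq="D")
tasmin = xr.DataArray(np.full(365, 270.0), dims="time", coords={"time": time}, attrs={"units": "K"})
tasmax = xr.DataArray(np.full(365, 270.0), dims="time", coords={"time": time}, attrs={"units": "K"})

# Passing a list to ``spell_reducer`` returns one output per statistic, in the same order.
count, total, longest = bivariate_spell_length_statistics(
    tasmin, "0 °C",
    tasmax, "1 °C",
    window=5,
    win_reducer="min",
    op="<",
    spell_reducer=["count", "sum", "max"],
    freq="YS",
)
# Here the whole year satisfies both conditions, so count should be 1 and
# total should equal longest (365 days), as in the new test case.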