From e129ac1d7442e6cb6f671aafc58509a167514cc4 Mon Sep 17 00:00:00 2001 From: omiralles Date: Wed, 4 Dec 2024 22:12:00 +0100 Subject: [PATCH 1/4] Lower bound delta lat in power spectrum plot and align input color map for precip plots --- .../config/diagnostics/plot/detailed.yaml | 1 + .../training/diagnostics/callbacks/plot.py | 3 + src/anemoi/training/diagnostics/plots.py | 64 +++++++++++++++++-- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index d1ac8b0f..f7d4e78e 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -44,6 +44,7 @@ callbacks: - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum # every_n_batches: 100 # Override for batch frequency + # min_delta: 0.01 # Minimum distance between two consecutive points sample_idx: ${diagnostics.plot.sample_idx} parameters: - z_500 diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 197a401e..3adcb55a 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -1018,6 +1018,7 @@ def __init__( config: OmegaConf, sample_idx: int, parameters: list[str], + min_delta: float | None = None, every_n_batches: int | None = None, ) -> None: """Initialise the PlotSpectrum callback. @@ -1036,6 +1037,7 @@ def __init__( super().__init__(config, every_n_batches=every_n_batches) self.sample_idx = sample_idx self.parameters = parameters + self.min_delta = min_delta @rank_zero_only def _plot( @@ -1070,6 +1072,7 @@ def _plot( data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], + min_delta=self.min_delta, ) self._output_figure( diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 45818b69..83e5fbac 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -138,6 +138,7 @@ def plot_power_spectrum( x: np.ndarray, y_true: np.ndarray, y_pred: np.ndarray, + min_delta: float | None = None, ) -> Figure: """Plots power spectrum. @@ -163,6 +164,7 @@ def plot_power_spectrum( The figure object handle. """ + min_delta = min_delta or 0.0003 n_plots_x, n_plots_y = len(parameters), 1 figsize = (n_plots_y * 4, n_plots_x * 3) @@ -177,9 +179,16 @@ def plot_power_spectrum( # Calculate delta_lat on the projected grid delta_lat = abs(np.diff(pc_lat)) non_zero_delta_lat = delta_lat[delta_lat != 0] + min_delta_lat = np.min(abs(non_zero_delta_lat)) + + if min_delta_lat < min_delta: + logging.warning( + f"Minimum distance between lat/lon points is less than the specified minimum distance. Defaulting to min_delta={min_delta}." + ) + min_delta_lat = min_delta # Define a regular grid for interpolation - n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) + n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / min_delta_lat)) n_pix_lon = (n_pix_lat - 1) * 2 + 1 # 2*lmax + 1 regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon) regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat) @@ -313,14 +322,14 @@ def plot_histogram( # enforce the same binning for both histograms bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt)) bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt)) - hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, range=[bin_min, bin_max]) - hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, range=[bin_min, bin_max]) + hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, density=True, range=[bin_min, bin_max]) + hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, density=True, range=[bin_min, bin_max]) else: # enforce the same binning for both histograms bin_min = min(np.nanmin(yt), np.nanmin(yp)) bin_max = max(np.nanmax(yt), np.nanmax(yp)) - hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, range=[bin_min, bin_max]) - hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max]) + hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max]) + hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max]) # Visualization trick for tp if variable_name in precip_and_related_fields: @@ -623,6 +632,51 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.", datashader=datashader, ) + elif vname in precip_and_related_fields: + # Create a custom colormap for precipitation + nws_precip_colors = cmap_precip + precip_colormap = ListedColormap(nws_precip_colors) + + # Defining the actual precipitation accumulation levels in mm + cummulation_lvls = clevels + norm = BoundaryNorm(cummulation_lvls, len(cummulation_lvls) + 1) + + # converting to mm from m + input_ *= 1000.0 + truth *= 1000.0 + pred *= 1000.0 + single_plot( + fig, + ax[0], + lon=lon, + lat=lat, + data=input_, + cmap=precip_colormap, + title=f"{vname} input", + datashader=datashader, + ) + single_plot( + fig, + ax[4], + lon=lon, + lat=lat, + data=pred - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} increment [pred - input]", + datashader=datashader, + ) + single_plot( + fig, + ax[5], + lon=lon, + lat=lat, + data=truth - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} persist err", + datashader=datashader, + ) else: single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader) single_plot( From f56f5ce6018894442c284d8f312fa6252d64a842 Mon Sep 17 00:00:00 2001 From: omiralles Date: Mon, 9 Dec 2024 14:13:52 +0100 Subject: [PATCH 2/4] Call LOGGER instead of logging for warning --- src/anemoi/training/diagnostics/plots.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 83e5fbac..8fb7b273 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -182,8 +182,9 @@ def plot_power_spectrum( min_delta_lat = np.min(abs(non_zero_delta_lat)) if min_delta_lat < min_delta: - logging.warning( - f"Minimum distance between lat/lon points is less than the specified minimum distance. Defaulting to min_delta={min_delta}." + LOGGER.warning( + "Minimum distance between lat/lon points is less than the specified minimum distance. Defaulting to min_delta=%s.", + min_delta, ) min_delta_lat = min_delta From bc68ee59b003081de1c105fe476542530e7f7613 Mon Sep 17 00:00:00 2001 From: omiralles Date: Mon, 9 Dec 2024 14:53:08 +0100 Subject: [PATCH 3/4] Update doc --- src/anemoi/training/diagnostics/plots.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 8fb7b273..cd99ce37 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -157,6 +157,8 @@ def plot_power_spectrum( Expected data of shape (lat*lon, nvar*level) y_pred : np.ndarray Predicted data of shape (lat*lon, nvar*level) + min_delta: float, optional + Minimum distance between lat/lon points, if None defaulted to 1km Returns ------- @@ -183,7 +185,7 @@ def plot_power_spectrum( if min_delta_lat < min_delta: LOGGER.warning( - "Minimum distance between lat/lon points is less than the specified minimum distance. Defaulting to min_delta=%s.", + "Min. distance between lat/lon points is < specified minimum distance. Defaulting to min_delta=%s.", min_delta, ) min_delta_lat = min_delta From 80e5c0dfc06384f0346aa5add24b837fccde0882 Mon Sep 17 00:00:00 2001 From: omiralles Date: Mon, 9 Dec 2024 14:59:04 +0100 Subject: [PATCH 4/4] Changelog update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5e3fbc3..3c4e0466 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you! - Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178) - Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180) - Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190) +- Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic). ### Added - Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.