Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Fixes to callback plots #182

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you!
- Not update NaN-weight-mask for loss function when using remapper and no imputer [#178](https://github.com/ecmwf/anemoi-training/pull/178)
- Dont crash when using the profiler if certain env vars arent set [#180](https://github.com/ecmwf/anemoi-training/pull/180)
- Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190)
- Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic).

### Added
- Introduce variable to configure: transfer_learning -> bool, True if loading checkpoint in a transfer learning setting.
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ callbacks:

- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# every_n_batches: 100 # Override for batch frequency
# min_delta: 0.01 # Minimum distance between two consecutive points
gabrieloks marked this conversation as resolved.
Show resolved Hide resolved
sample_idx: ${diagnostics.plot.sample_idx}
parameters:
- z_500
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ def __init__(
config: OmegaConf,
sample_idx: int,
parameters: list[str],
min_delta: float | None = None,
every_n_batches: int | None = None,
) -> None:
"""Initialise the PlotSpectrum callback.
Expand All @@ -1036,6 +1037,7 @@ def __init__(
super().__init__(config, every_n_batches=every_n_batches)
self.sample_idx = sample_idx
self.parameters = parameters
self.min_delta = min_delta

@rank_zero_only
def _plot(
Expand Down Expand Up @@ -1070,6 +1072,7 @@ def _plot(
data[0, ...].squeeze(),
data[rollout_step + 1, ...].squeeze(),
output_tensor[rollout_step, ...],
min_delta=self.min_delta,
)

self._output_figure(
Expand Down
67 changes: 62 additions & 5 deletions src/anemoi/training/diagnostics/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def plot_power_spectrum(
x: np.ndarray,
y_true: np.ndarray,
y_pred: np.ndarray,
min_delta: float | None = None,
) -> Figure:
"""Plots power spectrum.

Expand All @@ -156,13 +157,16 @@ def plot_power_spectrum(
Expected data of shape (lat*lon, nvar*level)
y_pred : np.ndarray
Predicted data of shape (lat*lon, nvar*level)
min_delta: float, optional
Minimum distance between lat/lon points, if None defaulted to 1km

Returns
-------
Figure
The figure object handle.

"""
min_delta = min_delta or 0.0003
gabrieloks marked this conversation as resolved.
Show resolved Hide resolved
n_plots_x, n_plots_y = len(parameters), 1

figsize = (n_plots_y * 4, n_plots_x * 3)
Expand All @@ -177,9 +181,17 @@ def plot_power_spectrum(
# Calculate delta_lat on the projected grid
delta_lat = abs(np.diff(pc_lat))
non_zero_delta_lat = delta_lat[delta_lat != 0]
min_delta_lat = np.min(abs(non_zero_delta_lat))

if min_delta_lat < min_delta:
LOGGER.warning(
"Min. distance between lat/lon points is < specified minimum distance. Defaulting to min_delta=%s.",
min_delta,
)
min_delta_lat = min_delta

# Define a regular grid for interpolation
n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat))))
n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / min_delta_lat))
n_pix_lon = (n_pix_lat - 1) * 2 + 1 # 2*lmax + 1
regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon)
regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat)
Expand Down Expand Up @@ -313,14 +325,14 @@ def plot_histogram(
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt))
bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt))
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, range=[bin_min, bin_max])
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, density=True, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, density=True, range=[bin_min, bin_max])
else:
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt), np.nanmin(yp))
bin_max = max(np.nanmax(yt), np.nanmax(yp))
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max])
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max])

# Visualization trick for tp
if variable_name in precip_and_related_fields:
Expand Down Expand Up @@ -623,6 +635,51 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
datashader=datashader,
)
elif vname in precip_and_related_fields:
# Create a custom colormap for precipitation
nws_precip_colors = cmap_precip
precip_colormap = ListedColormap(nws_precip_colors)

# Defining the actual precipitation accumulation levels in mm
cummulation_lvls = clevels
norm = BoundaryNorm(cummulation_lvls, len(cummulation_lvls) + 1)

# converting to mm from m
input_ *= 1000.0
truth *= 1000.0
pred *= 1000.0
single_plot(
fig,
ax[0],
lon=lon,
lat=lat,
data=input_,
cmap=precip_colormap,
title=f"{vname} input",
datashader=datashader,
)
single_plot(
fig,
ax[4],
lon=lon,
lat=lat,
data=pred - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} increment [pred - input]",
datashader=datashader,
)
single_plot(
fig,
ax[5],
lon=lon,
lat=lat,
data=truth - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} persist err",
datashader=datashader,
)
else:
single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader)
single_plot(
Expand Down
Loading