Skip to content

Commit

Permalink
Fixes from the review: change the axe getting method.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykraftdlr committed Sep 6, 2023
1 parent 026e4e1 commit f7eeefc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 47 deletions.
82 changes: 36 additions & 46 deletions esmvaltool/diag_scripts/monitor/multi_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@
figure_kwargs: dict, optional
Optional keyword arguments for :func:`matplotlib.pyplot.figure`. By
default, uses ``constrained_layout: true``.
group_variables_by: str, optional (default: 'short_name')
Facet which is used to create variable groups. For each variable group, an
individual plot is created.
plots: dict, optional
Plot types plotted by this diagnostic (see list above). Dictionary keys
must be ``timeseries``, ``annual_cycle``, ``map``, ``zonal_mean_profile``,
Expand Down Expand Up @@ -488,23 +485,20 @@ def __init__(self, config):
self.cfg = deepcopy(self.cfg)
self.cfg.setdefault('facet_used_for_labels', 'dataset')
self.cfg.setdefault('figure_kwargs', {'constrained_layout': True})
self.cfg.setdefault('group_variables_by', 'short_name')
self.cfg.setdefault('savefig_kwargs', {
'bbox_inches': 'tight',
'dpi': 300,
'orientation': 'landscape',
})
self.cfg.setdefault('seaborn_settings', {'style': 'ticks'})
logger.info("Using facet '%s' to group variables",
self.cfg['group_variables_by'])
logger.info("Using facet '%s' to create labels",
self.cfg['facet_used_for_labels'])

# Load input data
self.input_data = self._load_and_preprocess_data()
self.grouped_input_data = group_metadata(
self.input_data,
self.cfg['group_variables_by'],
'short_name',
sort=self.cfg['facet_used_for_labels'],
)

Expand Down Expand Up @@ -1235,6 +1229,7 @@ def _plot_hovmoeller_time_vs_lat_or_lon_with_ref(self, plot_func, dataset,
# Make sure that the data has the correct dimensions
cube = dataset['cube']
ref_cube = ref_dataset['cube']
print(cube)
dim_coords_dat = self._check_cube_dimensions(cube, plot_type)
self._check_cube_dimensions(ref_cube, plot_type)

Expand All @@ -1258,12 +1253,12 @@ def _plot_hovmoeller_time_vs_lat_or_lon_with_ref(self, plot_func, dataset,
plot_data = plot_func(cube, **plot_kwargs)
axes_data.set_title(self._get_label(dataset), pad=3.0)
axes_data.set_ylabel('Time / Year')
plt.gca().yaxis.set_major_formatter(mdates.DateFormatter(
axes_data.get_yaxis().set_major_formatter(mdates.DateFormatter(
self.plots[plot_type]['time_format']))
if self.plots[plot_type]['show_y_minor_ticks']:
plt.gca().yaxis.set_minor_locator(AutoMinorLocator())
axes_data.get_yaxis().set_minor_locator(AutoMinorLocator())
if self.plots[plot_type]['show_x_minor_ticks']:
plt.gca().xaxis.set_minor_locator(AutoMinorLocator())
axes_data.get_xaxis().set_minor_locator(AutoMinorLocator())

# Plot reference dataset (top right)
# Note: make sure to use the same vmin and vmax than the top left
Expand All @@ -1277,6 +1272,7 @@ def _plot_hovmoeller_time_vs_lat_or_lon_with_ref(self, plot_func, dataset,
plot_ref = plot_func(ref_cube, **plot_kwargs)
axes_ref.set_title(self._get_label(ref_dataset), pad=3.0)
plt.setp(axes_ref.get_yticklabels(), visible=False)
# self._add_stats(plot_type, axes_ref, dim_coords_ref, ref_dataset)

# Add colorbar(s)
self._add_colorbar(plot_type, plot_data, plot_ref, axes_data,
Expand Down Expand Up @@ -1315,7 +1311,7 @@ def _plot_hovmoeller_time_vs_lat_or_lon_with_ref(self, plot_func, dataset,
self._process_pyplot_kwargs(plot_type, dataset)

# Rasterization
if self.plots[plot_type]['rasterize']:
if self.plots[plot_type]['rasterize']: # TODO: not working?
self._set_rasterized([axes_data, axes_ref, axes_bias])

# File paths
Expand Down Expand Up @@ -1348,8 +1344,8 @@ def _plot_hovmoeller_time_vs_lat_or_lon_without_ref(self, plot_func,
axes = fig.add_subplot()
plot_kwargs = self._get_plot_kwargs(plot_type, dataset)
plot_kwargs['axes'] = axes

# Make sure time is on y-axis

plot_kwargs['coords'] = list(reversed(dim_coords_dat))
plot_hovmoeller = plot_func(cube, **plot_kwargs)

Expand All @@ -1370,16 +1366,17 @@ def _plot_hovmoeller_time_vs_lat_or_lon_without_ref(self, plot_func,
elif "longitude" in dim_coords_dat:
axes.set_xlabel('Longitude / °E')
axes.set_ylabel('Time / Year')
plt.gca().yaxis.set_major_formatter(mdates.DateFormatter(
axes.get_yaxis().set_major_formatter(mdates.DateFormatter(
self.plots[plot_type]['time_format'])
)
if self.plots[plot_type]['show_y_minor_ticks']:
plt.gca().yaxis.set_minor_locator(AutoMinorLocator())
axes.get_yaxis().set_minor_locator(AutoMinorLocator())
if self.plots[plot_type]['show_x_minor_ticks']:
plt.gca().xaxis.set_minor_locator(AutoMinorLocator())
axes.get_xaxis().set_minor_locator(AutoMinorLocator())
self._process_pyplot_kwargs(plot_type, dataset)

# Rasterization
# TODO: seems not to work
if self.plots[plot_type]['rasterize']:
self._set_rasterized([axes])

Expand Down Expand Up @@ -1455,21 +1452,21 @@ def _get_multi_dataset_facets(datasets):
multi_dataset_facets[key] = f'ambiguous_{key}'
return multi_dataset_facets

def _get_reference_dataset(self, datasets):
@staticmethod
def _get_reference_dataset(datasets, short_name):
"""Extract reference dataset."""
variable = datasets[0][self.cfg['group_variables_by']]
ref_datasets = [d for d in datasets if
d.get('reference_for_monitor_diags', False)]
if len(ref_datasets) > 1:
raise ValueError(
f"Expected at most 1 reference dataset (with "
f"'reference_for_monitor_diags: true' for variable "
f"'{variable}', got {len(ref_datasets):d}")
f"'{short_name}', got {len(ref_datasets):d}")
if ref_datasets:
return ref_datasets[0]
return None

def create_timeseries_plot(self, datasets):
def create_timeseries_plot(self, datasets, short_name):
"""Create time series plot."""
plot_type = 'timeseries'
if plot_type not in self.plots:
Expand Down Expand Up @@ -1512,10 +1509,7 @@ def create_timeseries_plot(self, datasets):
multi_dataset_facets = self._get_multi_dataset_facets(datasets)
axes.set_title(multi_dataset_facets['long_name'])
axes.set_xlabel('Time')
axes.set_ylabel(
f"{multi_dataset_facets[self.cfg['group_variables_by']]} "
f"[{multi_dataset_facets['units']}]"
)
axes.set_ylabel(f"{short_name} [{multi_dataset_facets['units']}]")
gridline_kwargs = self._get_gridline_kwargs(plot_type)
if gridline_kwargs is not False:
axes.grid(**gridline_kwargs)
Expand Down Expand Up @@ -1555,7 +1549,7 @@ def create_timeseries_plot(self, datasets):
provenance_logger.log(plot_path, provenance_record)
provenance_logger.log(netcdf_path, provenance_record)

def create_annual_cycle_plot(self, datasets):
def create_annual_cycle_plot(self, datasets, short_name):
"""Create annual cycle plot."""
plot_type = 'annual_cycle'
if plot_type not in self.plots:
Expand Down Expand Up @@ -1586,10 +1580,7 @@ def create_annual_cycle_plot(self, datasets):
multi_dataset_facets = self._get_multi_dataset_facets(datasets)
axes.set_title(multi_dataset_facets['long_name'])
axes.set_xlabel('Month')
axes.set_ylabel(
f"{multi_dataset_facets[self.cfg['group_variables_by']]} "
f"[{multi_dataset_facets['units']}]"
)
axes.set_ylabel(f"{short_name} [{multi_dataset_facets['units']}]")
axes.set_xticks(range(1, 13), [str(m) for m in range(1, 13)])
gridline_kwargs = self._get_gridline_kwargs(plot_type)
if gridline_kwargs is not False:
Expand Down Expand Up @@ -1630,7 +1621,7 @@ def create_annual_cycle_plot(self, datasets):
provenance_logger.log(plot_path, provenance_record)
provenance_logger.log(netcdf_path, provenance_record)

def create_map_plot(self, datasets):
def create_map_plot(self, datasets, short_name):
"""Create map plot."""
plot_type = 'map'
if plot_type not in self.plots:
Expand All @@ -1640,7 +1631,7 @@ def create_map_plot(self, datasets):
raise ValueError(f"No input data to plot '{plot_type}' given")

# Get reference dataset if possible
ref_dataset = self._get_reference_dataset(datasets)
ref_dataset = self._get_reference_dataset(datasets, short_name)
if ref_dataset is None:
logger.info("Plotting %s without reference dataset", plot_type)
else:
Expand Down Expand Up @@ -1706,7 +1697,7 @@ def create_map_plot(self, datasets):
for netcdf_path in netcdf_paths:
provenance_logger.log(netcdf_path, provenance_record)

def create_zonal_mean_profile_plot(self, datasets):
def create_zonal_mean_profile_plot(self, datasets, short_name):
"""Create zonal mean profile plot."""
plot_type = 'zonal_mean_profile'
if plot_type not in self.plots:
Expand All @@ -1716,7 +1707,7 @@ def create_zonal_mean_profile_plot(self, datasets):
raise ValueError(f"No input data to plot '{plot_type}' given")

# Get reference dataset if possible
ref_dataset = self._get_reference_dataset(datasets)
ref_dataset = self._get_reference_dataset(datasets, short_name)
if ref_dataset is None:
logger.info("Plotting %s without reference dataset", plot_type)
else:
Expand Down Expand Up @@ -1784,7 +1775,7 @@ def create_zonal_mean_profile_plot(self, datasets):
for netcdf_path in netcdf_paths:
provenance_logger.log(netcdf_path, provenance_record)

def create_1d_profile_plot(self, datasets):
def create_1d_profile_plot(self, datasets, short_name):
"""Create 1D profile plot."""
plot_type = '1d_profile'
if plot_type not in self.plots:
Expand Down Expand Up @@ -1816,10 +1807,7 @@ def create_1d_profile_plot(self, datasets):

# Default plot appearance
axes.set_title(multi_dataset_facets['long_name'])
axes.set_xlabel(
f"{multi_dataset_facets[self.cfg['group_variables_by']]} "
f"[{multi_dataset_facets['units']}]"
)
axes.set_xlabel(f"{short_name} [{multi_dataset_facets['units']}]")
z_coord = cube.coord(axis='Z')
axes.set_ylabel(f'{z_coord.long_name} [{z_coord.units}]')

Expand Down Expand Up @@ -1950,11 +1938,12 @@ def create_hovmoeller_time_vs_lat_or_lon_plot(self, datasets, short_name):
io.iris_save(cube, netcdf_path)

# Provenance tracking
# TODO: update provenance
provenance_record = {
'ancestors': ancestors,
'authors': ['schlund_manuel', 'kraft_jeremy', 'ruhe_lukas'],
'caption': caption,
'plot_types': ['zonal'],
'plot_types': ['vert'],
'long_names': [dataset['long_name']],
}
with ProvenanceLogger(self.cfg) as provenance_logger:
Expand All @@ -1964,15 +1953,16 @@ def create_hovmoeller_time_vs_lat_or_lon_plot(self, datasets, short_name):

def compute(self):
"""Plot preprocessed data."""
for (var_key, datasets) in self.grouped_input_data.items():
logger.info("Processing variable %s", var_key)
self.create_timeseries_plot(datasets)
self.create_annual_cycle_plot(datasets)
self.create_map_plot(datasets)
self.create_zonal_mean_profile_plot(datasets)
self.create_1d_profile_plot(datasets)
for (short_name, datasets) in self.grouped_input_data.items():
logger.info("Processing variable %s", short_name)
self.create_timeseries_plot(datasets, short_name)
self.create_annual_cycle_plot(datasets, short_name)
self.create_map_plot(datasets, short_name)
self.create_zonal_mean_profile_plot(datasets, short_name)
self.create_1d_profile_plot(datasets, short_name)
self.create_hovmoeller_time_vs_lat_or_lon_plot(
datasets
datasets,
short_name
)


Expand Down
1 change: 0 additions & 1 deletion esmvaltool/recipes/monitor/recipe_monitor_with_refs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ diagnostics:
plot_kwargs_bias:
levels: [-10.0, -7.5, -5.0, -2.5, 0.0, 2.5, 5.0, 7.5, 10.0]


plot_1D_profiles_with_references:
description: Plot 1D profiles including reference datasets.
variables:
Expand Down

0 comments on commit f7eeefc

Please sign in to comment.