From 2e4673415185a1d0670558b877e8d37bdb5b503c Mon Sep 17 00:00:00 2001 From: abeaupre Date: Thu, 12 Jan 2023 16:55:08 -0500 Subject: [PATCH 01/25] creation fichiers --- spirograph/matplotlib/get_example_data.py | 14 ++++++++++++++ spirograph/matplotlib/timeseries.py | 7 +++++++ 2 files changed, 21 insertions(+) create mode 100644 spirograph/matplotlib/get_example_data.py create mode 100644 spirograph/matplotlib/timeseries.py diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py new file mode 100644 index 00000000..117c6bed --- /dev/null +++ b/spirograph/matplotlib/get_example_data.py @@ -0,0 +1,14 @@ +import xarray as xr + + + + + +# ensemble percentiles Xclim - dataset where each variable represents a percentile, with lat, lon, time coords +url_1 = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/txgt_32/YS/ssp585/ensemble_percentiles/txgt_32_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_percentiles.nc' + +ds_pct = xr.open_dataset(url_1, decode_cf= False) + + +da_pct = ds_pct.sel() + diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py new file mode 100644 index 00000000..fff37487 --- /dev/null +++ b/spirograph/matplotlib/timeseries.py @@ -0,0 +1,7 @@ + +import numpy as np +import xarray as xr +import matplotlib.pyplot as plt + + +#decode cf = False From 6c31391129c4c8933ea0dd79b633042e03b2f66f Mon Sep 17 00:00:00 2001 From: abeaupre Date: Mon, 16 Jan 2023 13:43:15 -0500 Subject: [PATCH 02/25] w.i.p. --- spirograph/matplotlib/get_example_data.py | 18 +++-- spirograph/matplotlib/timeseries.py | 42 +++++++++- spirograph/matplotlib/util_fcts.py | 95 +++++++++++++++++++++++ 3 files changed, 146 insertions(+), 9 deletions(-) create mode 100644 spirograph/matplotlib/util_fcts.py diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index 117c6bed..65183d34 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -1,14 +1,16 @@ import xarray as xr +# ensemble percentiles Xclim - dataset where each variable represents a percentile with lat, lon, time dims +url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' +ds_pct_open = xr.open_dataset(url_1, decode_cf= False) +ds_pct = ds_pct_open.isel(lon = 500, lat = 250)[['tx_max_p50', 'tx_max_p10', 'tx_max_p90']] +da_pct = ds_pct['tx_max_p50'] +# DataArray simple, tasmax across time at a given point in space +url_2 = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/datasets/simulations/bias_adjusted/cmip6/pcic/CanDCS-U6/day_BCCAQv2+ANUSPLIN300_UKESM1-0-LL_historical+ssp585_r1i1p1f2_gn_1950-2100.ncml' +ds_var_open = xr.open_dataset(url_2, decode_cf= False) +ds_var = ds_var_open.isel(lon = 500, lat = 250) +da_var = ds_tasmax_open.isel(lon = 500, lat = 250)['tasmax'] -# ensemble percentiles Xclim - dataset where each variable represents a percentile, with lat, lon, time coords -url_1 = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/txgt_32/YS/ssp585/ensemble_percentiles/txgt_32_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_percentiles.nc' - -ds_pct = xr.open_dataset(url_1, decode_cf= False) - - -da_pct = ds_pct.sel() - diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index fff37487..a2eb3e1e 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -4,4 +4,44 @@ import matplotlib.pyplot as plt -#decode cf = False + +# To add +# language + + +## idees pour labels: arg replace_label + + +def line_ts(da, ax=None, dict_metadata=None, sub_kw=None, line_kw=None): + """ + plot unique time series from dataframe + ax: user-specified matplotlib axis + data: dataset ou dataframe xarray + dict_metadata: join figure element to xarray dataset element + sub_kw: matplotlib subplot kwargs + line_kw : maplotlib or xarray line kwargs + """ + #return empty dicts if no kwargs + kwargs = empty_dic({'sub_kw': sub_kw, 'line_kw': line_kw}) + + #initialize fig, ax if ax not provided + if not ax: + fig, ax = plt.subplots(**kwargs['sub_kw']) + + #plot + da.plot.line(ax=ax, **kwargs['line_kw']) + + #add/modify plot elements + if dict_metadata: + ax_dict_metadata(ax, dict_metadata, da, 'lines') + if 'label' in dict_metadata: + ax.legend() + return ax + + + +#out of function + +da = da_pct +dict_metadata = {'title':'my custom title'} +da_pct.plot.line() diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py new file mode 100644 index 00000000..b5e1b995 --- /dev/null +++ b/spirograph/matplotlib/util_fcts.py @@ -0,0 +1,95 @@ + +def empty_dic(kwargs): + for k, v in kwargs.items(): + if not v: + kwargs[k] = {} + return kwargs + +def check_metadata(xr, str): + if str in xr.attrs: + return xr.attrs[str] + else: + print('Metadata "{0}" not found in "{1}"'.format(str, xr.name)) + #if str in xr.coords: pas sur si peut vraiment faire de quoi si dans coords .... plutôt loop a l'extérieur + #ajouter message d'erreur si trouve pas le metadata et aussi voir comment ajouter les dimensions de façon intelligente + +def ax_dict_metadata(ax, dict_metadata, xr, type): + if 'title' in dict_metadata: + ax.set_title(check_metadata(xr, dict_metadata['title']), wrap=True) + if 'label' in dict_metadata: #est-ce que retire et utilise uniquement dict? + eva = getattr(ax, type) + eva[-1].set_label('line 1') + if 'xlabel' in dict_metadata: + ax.set_xlabel(check_metadata(xr, dict_metadata['xlabel'])) #rotation? + if 'ylabel' in dict_metadata: + ax.set_ylabel(check_metadata(xr, dict_metadata['ylabel'])) + return ax + +#transform ds into df +def xr_pd(xr): + if "Dataset" in str(type(xr)): + return xr.to_dataframe().reset_index() + else: + return xr.to_dataframe(name='values').reset_index() + + +def da_time_serie_line(da, ax=None, dict_metadata=None, sub_kw=None, line_kw=None, logo=False): + """ + plot unique time serie from dataset + da: dataset xarray + ax: matplotlib axis + dict_metadata: join figure element to xarray dataset element + sub_kw: matplotlib subplot kwargs + line_kw : maplotlib line kwargs + """ + kwargs = empty_dic({'sub_kw': sub_kw, 'line_kw': line_kw}) + if not ax: + fig, ax = plt.subplots(**kwargs['sub_kw']) + da.plot.line(ax=ax, **kwargs['line_kw']) + if dict_metadata: + ax_dict_metadata(ax, dict_metadata, da, 'lines') + if 'label' in dict_metadata: + ax.legend() + return ax + + + +def ens_time_serie(dict_xr, ax=None, dict_metadata=None, sub_kw=None, line_kwargs=None): + """ + dict of xr object: only one no legend, more than one legend + if dictionnary, the keys will be used to name the ensembles + + if to be created over coord (ex: horizon.... fait quoi?) + + """ + kwargs = empty_dic({'sub_kw': sub_kw, 'line_kwargs': line_kwargs}) + if not ax: + fig, ax = plt.subplots(**kwargs['sub_kw']) + + if type(dict_xr) != dict: + dict_xr = {'one': dict_xr} + + if len(dict_xr) == 1: + df = xr_pd(list(dict_xr.values())[0]).drop(columns=['lat', 'lon']) + if "Dataset" in str(type(list(dict_xr.values())[0])): + df = pd.melt(df, ['time']) + sns.lineplot(data=df, x='time', y='value', ax=ax, **kwargs['line_kwargs']) + else: + n = 0 + for k, v in dict_xr.items(): + df = xr_pd(v).drop(columns=['lat', 'lon']) + if "Dataset" in str(type(v)): + df = pd.melt(df, ['time'], value_name=k) + else: + df = df.rename(columns={'values': k}) + if n == 0: + dfa = df + else: + dfa[k] = df[k] + n = n+1 + dfa = pd.melt(dfa, ['time'], list(dict_xr.keys())) + sns.lineplot(data=dfa, x='time', y='value', hue='variable', ax=ax, **kwargs['line_kwargs']) #ajouter palette horizon si possible - option rcp/ssp ou détecter automatique? + if dict_metadata: + ax_dict_metadata(ax, dict_metadata, list(dict_xr.values())[0], 'line') + return ax +##### Fin code Sarah-Claude ##### From 5f75c601d3763f04b21f1db276dd02d9d2cf0f02 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Mon, 16 Jan 2023 17:05:23 -0500 Subject: [PATCH 03/25] fonctions utilitaires de base --- spirograph/matplotlib/timeseries.py | 38 +++++++++++++++------------ spirograph/matplotlib/util_fcts.py | 40 +++++++++++++++++------------ 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index a2eb3e1e..44c195d5 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -7,41 +7,47 @@ # To add # language +# smoothing +# m??anual mode: asks for input for title, ## idees pour labels: arg replace_label -def line_ts(da, ax=None, dict_metadata=None, sub_kw=None, line_kw=None): +def line_ts(da, ax=None, use_attrs=None, sub_kw=None, line_kw=None): """ - plot unique time series from dataframe - ax: user-specified matplotlib axis - data: dataset ou dataframe xarray - dict_metadata: join figure element to xarray dataset element - sub_kw: matplotlib subplot kwargs - line_kw : maplotlib or xarray line kwargs + Plots unique time series from dataframe + + Args: + ax: user-specified matplotlib axis + data: Xarray DataArray containing the data to plot + dict_metadata: dict linking a plot element (key, e.g. 'title') + to a DataArray attribute (value, e.g. 'Description') + sub_kw: matplotlib subplot kwargs + line_kw : maplotlib or xarray line kwargs + + Returns: + matplotlib axis """ #return empty dicts if no kwargs - kwargs = empty_dic({'sub_kw': sub_kw, 'line_kw': line_kw}) + kwargs = empty_dict({'sub_kw': sub_kw, 'line_kw': line_kw}) #initialize fig, ax if ax not provided if not ax: fig, ax = plt.subplots(**kwargs['sub_kw']) #plot - da.plot.line(ax=ax, **kwargs['line_kw']) + line_1 = da.plot.line(ax=ax, **kwargs['line_kw']) + #line_1.set_label() #add/modify plot elements - if dict_metadata: - ax_dict_metadata(ax, dict_metadata, da, 'lines') - if 'label' in dict_metadata: - ax.legend() + if use_attrs: + ax_dict_metadata(ax, use_attrs, da) + return ax #out of function -da = da_pct -dict_metadata = {'title':'my custom title'} -da_pct.plot.line() +line_ts(da_pct, use_attrs= {'title': 'ccdp_name'}) diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index b5e1b995..d9ddac1d 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -1,28 +1,36 @@ -def empty_dic(kwargs): +def empty_dict(kwargs): + """Returns empty dictionaries + """ for k, v in kwargs.items(): if not v: kwargs[k] = {} return kwargs -def check_metadata(xr, str): - if str in xr.attrs: - return xr.attrs[str] +def get_metadata(xr_obj, str): + """ + Fetches attributes corresponding to their key from Xarray objects + + Args: + xr: Xarray DataArray or Dataset + str: string corresponding to an attribute key + + Returns: + Xarray attribute value as string + """ + if str in xr_obj.attrs: + return xr_obj.attrs[str] else: - print('Metadata "{0}" not found in "{1}"'.format(str, xr.name)) + raise Exception('Metadata "{0}" not found in "{1}"'.format(str, xr_obj.name)) #if str in xr.coords: pas sur si peut vraiment faire de quoi si dans coords .... plutôt loop a l'extérieur - #ajouter message d'erreur si trouve pas le metadata et aussi voir comment ajouter les dimensions de façon intelligente -def ax_dict_metadata(ax, dict_metadata, xr, type): - if 'title' in dict_metadata: - ax.set_title(check_metadata(xr, dict_metadata['title']), wrap=True) - if 'label' in dict_metadata: #est-ce que retire et utilise uniquement dict? - eva = getattr(ax, type) - eva[-1].set_label('line 1') - if 'xlabel' in dict_metadata: - ax.set_xlabel(check_metadata(xr, dict_metadata['xlabel'])) #rotation? - if 'ylabel' in dict_metadata: - ax.set_ylabel(check_metadata(xr, dict_metadata['ylabel'])) +def ax_dict_metadata(ax, use_attrs, xr_obj): + if 'title' in use_attrs: + ax.set_title(get_metadata(xr_obj, use_attrs['title']), wrap=True) + if 'xlabel' in use_attrs: + ax.set_xlabel(get_metadata(xr_obj, use_attrs['xlabel'])) #rotation? + if 'ylabel' in use_attrs: + ax.set_ylabel(get_metadata(xr_obj, use_attrs['ylabel'])) return ax #transform ds into df From b5af3d03fbba608271ab15ec95ec299d9148c1a1 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Tue, 17 Jan 2023 15:03:12 -0500 Subject: [PATCH 04/25] ajout fonctionnalite basic Dataset --- spirograph/matplotlib/timeseries.py | 37 ++++++----- spirograph/matplotlib/util_fcts.py | 95 ++++++----------------------- 2 files changed, 42 insertions(+), 90 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 44c195d5..ed17a99f 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -8,46 +8,55 @@ # To add # language # smoothing -# m??anual mode: asks for input for title, +# manual mode? +# logo +# default value for attributes to use? ## idees pour labels: arg replace_label -def line_ts(da, ax=None, use_attrs=None, sub_kw=None, line_kw=None): +def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): """ Plots unique time series from dataframe - Args: ax: user-specified matplotlib axis - data: Xarray DataArray containing the data to plot - dict_metadata: dict linking a plot element (key, e.g. 'title') + da: Xarray DataArray containing the data to plot + use_attrs: dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description') sub_kw: matplotlib subplot kwargs line_kw : maplotlib or xarray line kwargs - Returns: matplotlib axis """ - #return empty dicts if no kwargs kwargs = empty_dict({'sub_kw': sub_kw, 'line_kw': line_kw}) - #initialize fig, ax if ax not provided if not ax: fig, ax = plt.subplots(**kwargs['sub_kw']) + #arrange data + plot_dict = {} + if str(type(data)) == "": + print('DATASET DETECTED') + for k,v in data.data_vars.items(): + plot_dict[k] = v + else: + plot_dict[data.name] = data + + print(plot_dict) + #plot - line_1 = da.plot.line(ax=ax, **kwargs['line_kw']) - #line_1.set_label() + for name, xr in plot_dict.items(): + #da.plot.line(ax=ax, **kwargs['line_kw']) # using xarray plotting + ax.plot(xr[xr.dims[0]], xr.values, label = name) #assumes the only dim is time + #add/modify plot elements if use_attrs: - ax_dict_metadata(ax, use_attrs, da) + set_plot_attrs(use_attrs, data, ax) return ax - - -#out of function +#test line_ts(da_pct, use_attrs= {'title': 'ccdp_name'}) diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index d9ddac1d..a6547d23 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -7,97 +7,40 @@ def empty_dict(kwargs): kwargs[k] = {} return kwargs -def get_metadata(xr_obj, str): +def get_attributes(xr_obj, str): """ - Fetches attributes corresponding to their key from Xarray objects - + Fetches attributes corresponding to keys from Xarray objects Args: - xr: Xarray DataArray or Dataset + xr_obj: Xarray DataArray or Dataset str: string corresponding to an attribute key - Returns: Xarray attribute value as string """ - if str in xr_obj.attrs: +if str in xr_obj.attrs: return xr_obj.attrs[str] else: - raise Exception('Metadata "{0}" not found in "{1}"'.format(str, xr_obj.name)) - #if str in xr.coords: pas sur si peut vraiment faire de quoi si dans coords .... plutôt loop a l'extérieur - -def ax_dict_metadata(ax, use_attrs, xr_obj): - if 'title' in use_attrs: - ax.set_title(get_metadata(xr_obj, use_attrs['title']), wrap=True) - if 'xlabel' in use_attrs: - ax.set_xlabel(get_metadata(xr_obj, use_attrs['xlabel'])) #rotation? - if 'ylabel' in use_attrs: - ax.set_ylabel(get_metadata(xr_obj, use_attrs['ylabel'])) - return ax - -#transform ds into df -def xr_pd(xr): - if "Dataset" in str(type(xr)): - return xr.to_dataframe().reset_index() - else: - return xr.to_dataframe(name='values').reset_index() - + raise Exception('Attribute "{0}" not found in "{1}"'.format(str, xr_obj.name)) -def da_time_serie_line(da, ax=None, dict_metadata=None, sub_kw=None, line_kw=None, logo=False): +def set_plot_attrs(attr_dict, xr_obj, ax): """ - plot unique time serie from dataset - da: dataset xarray - ax: matplotlib axis - dict_metadata: join figure element to xarray dataset element - sub_kw: matplotlib subplot kwargs - line_kw : maplotlib line kwargs + Sets plot elements according to DataArray attributes. Uses get_attributes() + Args: + use_attrs (dict): dict containing specified attribute keys + xr_obj: Xarray DataArray + ax: matplotlib axis + Returns: + matplotlib axis """ - kwargs = empty_dic({'sub_kw': sub_kw, 'line_kw': line_kw}) - if not ax: - fig, ax = plt.subplots(**kwargs['sub_kw']) - da.plot.line(ax=ax, **kwargs['line_kw']) - if dict_metadata: - ax_dict_metadata(ax, dict_metadata, da, 'lines') - if 'label' in dict_metadata: - ax.legend() + if 'title' in attr_dict: + ax.set_title(get_attributes(xr_obj, attr_dict['title']), wrap=True) + if 'xlabel' in attr_dict: + ax.set_xlabel(get_attributes(xr_obj, attr_dict['xlabel'])) #rotation? + if 'ylabel' in attr_dict: + ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel'])) return ax -def ens_time_serie(dict_xr, ax=None, dict_metadata=None, sub_kw=None, line_kwargs=None): - """ - dict of xr object: only one no legend, more than one legend - if dictionnary, the keys will be used to name the ensembles - if to be created over coord (ex: horizon.... fait quoi?) - """ - kwargs = empty_dic({'sub_kw': sub_kw, 'line_kwargs': line_kwargs}) - if not ax: - fig, ax = plt.subplots(**kwargs['sub_kw']) - if type(dict_xr) != dict: - dict_xr = {'one': dict_xr} - - if len(dict_xr) == 1: - df = xr_pd(list(dict_xr.values())[0]).drop(columns=['lat', 'lon']) - if "Dataset" in str(type(list(dict_xr.values())[0])): - df = pd.melt(df, ['time']) - sns.lineplot(data=df, x='time', y='value', ax=ax, **kwargs['line_kwargs']) - else: - n = 0 - for k, v in dict_xr.items(): - df = xr_pd(v).drop(columns=['lat', 'lon']) - if "Dataset" in str(type(v)): - df = pd.melt(df, ['time'], value_name=k) - else: - df = df.rename(columns={'values': k}) - if n == 0: - dfa = df - else: - dfa[k] = df[k] - n = n+1 - dfa = pd.melt(dfa, ['time'], list(dict_xr.keys())) - sns.lineplot(data=dfa, x='time', y='value', hue='variable', ax=ax, **kwargs['line_kwargs']) #ajouter palette horizon si possible - option rcp/ssp ou détecter automatique? - if dict_metadata: - ax_dict_metadata(ax, dict_metadata, list(dict_xr.values())[0], 'line') - return ax -##### Fin code Sarah-Claude ##### From 48f2636d0d63a8b79db8c96474b55fccdfdb1784 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Tue, 17 Jan 2023 18:45:26 -0500 Subject: [PATCH 05/25] started ensemble functionality ad created sort_lines fct --- spirograph/matplotlib/timeseries.py | 15 +++++++-- spirograph/matplotlib/util_fcts.py | 49 +++++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index ed17a99f..8546ef8a 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -16,7 +16,7 @@ ## idees pour labels: arg replace_label -def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): +def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, ensemble = False): """ Plots unique time series from dataframe Args: @@ -37,24 +37,33 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): #arrange data plot_dict = {} if str(type(data)) == "": - print('DATASET DETECTED') for k,v in data.data_vars.items(): plot_dict[k] = v else: plot_dict[data.name] = data - print(plot_dict) + #set up for ensemble + sorted_line_y = [] + sorted_line_x = [] #plot for name, xr in plot_dict.items(): #da.plot.line(ax=ax, **kwargs['line_kw']) # using xarray plotting ax.plot(xr[xr.dims[0]], xr.values, label = name) #assumes the only dim is time + if ensemble is True: + sorted_line_x.append(xr[xr.dims[0]]) + sorted_line_y.append(xr.values) + + if ensemble is True: + ax.fill_between() #add/modify plot elements if use_attrs: set_plot_attrs(use_attrs, data, ax) + ax.legend() + return ax #test diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index a6547d23..3477b7cc 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -1,4 +1,8 @@ +import pandas as pd +import xarray as xr + + def empty_dict(kwargs): """Returns empty dictionaries """ @@ -16,8 +20,8 @@ def get_attributes(xr_obj, str): Returns: Xarray attribute value as string """ -if str in xr_obj.attrs: - return xr_obj.attrs[str] + if str in xr_obj.attrs: + return xr_obj.attrs[str] else: raise Exception('Attribute "{0}" not found in "{1}"'.format(str, xr_obj.name)) @@ -40,6 +44,47 @@ def set_plot_attrs(attr_dict, xr_obj, ax): return ax +def sort_lines(array_dict): + """ + Sorts same-length parallel (not crossing across x-y coordinates) arrays by y coordinates + Args: + array_dict: dict of arrays, must contain an odd number of arrays + Returns: + dict of names + To do: different lengths?, warning if even number + """ + ref_values = {} + sorted_lines = {} + for name, xr in array_dict.items(): + ref_values[name] = int(xr[int(len(xr)/2)]) + sorted_series = pd.Series(ref_values).sort_values() + sorted_lines['upper'] = sorted_series.idxmax() + sorted_lines['lower'] = sorted_series.idxmin() + + return sorted_lines + + + + + + + + + + +# lnx = np.arange(1,10,1) +# ln1 = lnx + 3 + 1*np.random.rand() +# ln2 = lnx + 10 + 3*np.random.rand() +# ln3 = lnx + 6 + 2*np.random.rand() +# +# fig, ax = plt.subplots(figsize = (4,3)) +# ax.plot(lnx, ln1) +# ax.plot(lnx, ln2) +# ax.plot(lnx, ln3) +# +# ax.fill_between(lnx, ln1,ln2, alpha = 0.2, color = 'red') +# ax.fill_between(lnx, ln1,ln3, alpha = 0.2, color = 'blue') +# plt.show() From b07a7a3b08584d435f9ad3acdbf7cbbc176b37ca Mon Sep 17 00:00:00 2001 From: abeaupre Date: Wed, 18 Jan 2023 12:07:27 -0500 Subject: [PATCH 06/25] ensemble function linestyle improvements --- spirograph/matplotlib/timeseries.py | 40 +++++++++++++++-------------- spirograph/matplotlib/util_fcts.py | 15 ++++++----- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 8546ef8a..8b9d8731 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -16,16 +16,16 @@ ## idees pour labels: arg replace_label -def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, ensemble = False): +def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_kw=None): """ - Plots unique time series from dataframe + Plots unique time series from 1D dataframe or dataset Args: ax: user-specified matplotlib axis da: Xarray DataArray containing the data to plot use_attrs: dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description') sub_kw: matplotlib subplot kwargs - line_kw : maplotlib or xarray line kwargs + line_kw : matplotlib or xarray line kwargs Returns: matplotlib axis """ @@ -35,28 +35,29 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, ensemble = fig, ax = plt.subplots(**kwargs['sub_kw']) #arrange data - plot_dict = {} + array_dict = {} if str(type(data)) == "": - for k,v in data.data_vars.items(): - plot_dict[k] = v + for k, v in data.data_vars.items(): + array_dict[k] = v + if ensemble is True: + sorted_lines = sort_lines(array_dict) else: - plot_dict[data.name] = data - - #set up for ensemble - sorted_line_y = [] - sorted_line_x = [] + array_dict[data.name] = data #plot - for name, xr in plot_dict.items(): - #da.plot.line(ax=ax, **kwargs['line_kw']) # using xarray plotting - ax.plot(xr[xr.dims[0]], xr.values, label = name) #assumes the only dim is time + if ensemble is True: - if ensemble is True: - sorted_line_x.append(xr[xr.dims[0]]) - sorted_line_y.append(xr.values) + ax.plot(array_dict[sorted_lines['middle']][array_dict[sorted_lines['middle']].dims[0]], + array_dict[sorted_lines['middle']].values, **kwargs['line_kw']) - if ensemble is True: - ax.fill_between() + ax.fill_between(array_dict[sorted_lines['lower']][array_dict[sorted_lines['lower']].dims[0]], + array_dict[sorted_lines['lower']].values, + array_dict[sorted_lines['upper']].values, + alpha = 0.2) + else: + for name, arr in array_dict.items(): + #da.plot.line(ax=ax, **kwargs['line_kw']) # using xarray plotting + ax.plot(arr[arr.dims[0]], arr.values,label = name, **kwargs['line_kw']) #add/modify plot elements if use_attrs: @@ -69,3 +70,4 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, ensemble = #test line_ts(da_pct, use_attrs= {'title': 'ccdp_name'}) +line_ts(ds_pct, use_attrs= {'title': 'ccdp_name'}, ensemble = True) diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index 3477b7cc..0aefffd7 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -46,20 +46,21 @@ def set_plot_attrs(attr_dict, xr_obj, ax): def sort_lines(array_dict): """ - Sorts same-length parallel (not crossing across x-y coordinates) arrays by y coordinates + Sorts and labels same-length arrays that plot as parallel lines in x,y space + according to the highest and lowest along the y-axis Args: - array_dict: dict of arrays, must contain an odd number of arrays + array_dict: dict of arrays. Returns: - dict of names - To do: different lengths?, warning if even number + dict """ ref_values = {} sorted_lines = {} - for name, xr in array_dict.items(): - ref_values[name] = int(xr[int(len(xr)/2)]) + for name, arr in array_dict.items(): + ref_values[name] = int(arr[int(len(arr)/2)]) sorted_series = pd.Series(ref_values).sort_values() - sorted_lines['upper'] = sorted_series.idxmax() sorted_lines['lower'] = sorted_series.idxmin() + sorted_lines['upper'] = sorted_series.idxmax() + sorted_lines['middle'] = sorted_series.index[int(len(sorted_series)/2 - 0.5)] # -0.5 is + 0.5 - 1, to account for 0-indexing return sorted_lines From 53591ac0194f636502c4d45e6f143737d75a92ed Mon Sep 17 00:00:00 2001 From: abeaupre Date: Wed, 18 Jan 2023 18:19:06 -0500 Subject: [PATCH 07/25] implemented attributes pull from Xarray objects --- spirograph/matplotlib/timeseries.py | 22 ++++++++++------ spirograph/matplotlib/util_fcts.py | 40 ++++++++++++++++++++++------- 2 files changed, 45 insertions(+), 17 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 8b9d8731..7601dbbf 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -8,9 +8,9 @@ # To add # language # smoothing -# manual mode? # logo -# default value for attributes to use? +# transformer dates de days since... a dates? +# use_attrs = default au lieu de None?? ## idees pour labels: arg replace_label @@ -47,21 +47,26 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k #plot if ensemble is True: - ax.plot(array_dict[sorted_lines['middle']][array_dict[sorted_lines['middle']].dims[0]], + line_1 = ax.plot(array_dict[sorted_lines['middle']][array_dict[sorted_lines['middle']].dims[0]], array_dict[sorted_lines['middle']].values, **kwargs['line_kw']) ax.fill_between(array_dict[sorted_lines['lower']][array_dict[sorted_lines['lower']].dims[0]], array_dict[sorted_lines['lower']].values, array_dict[sorted_lines['upper']].values, - alpha = 0.2) + color = line_1[0].get_color(), + edgecolor = 'white', alpha = 0.2) else: for name, arr in array_dict.items(): - #da.plot.line(ax=ax, **kwargs['line_kw']) # using xarray plotting ax.plot(arr[arr.dims[0]], arr.values,label = name, **kwargs['line_kw']) #add/modify plot elements + plot_attrs = default_attrs(data) if use_attrs: - set_plot_attrs(use_attrs, data, ax) + for k, v in use_attrs.items(): + plot_attrs[k] = v + + set_plot_attrs(plot_attrs, data, ax) + ax.legend() @@ -69,5 +74,6 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k #test -line_ts(da_pct, use_attrs= {'title': 'ccdp_name'}) -line_ts(ds_pct, use_attrs= {'title': 'ccdp_name'}, ensemble = True) +line_ts(da_pct, use_attrs={'title': 'ccdp_name'}, line_kw = {'color': 'red'}) +line_ts(ds_pct, ensemble=True, line_kw={'color': 'red'}) +line_ts(ds_pct, use_attrs= {'title': 'ccdp_name'}, ensemble = True,line_kw = {'color':'red'}) diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index 0aefffd7..d26140fc 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -13,7 +13,7 @@ def empty_dict(kwargs): def get_attributes(xr_obj, str): """ - Fetches attributes corresponding to keys from Xarray objects + Fetches attributes or dims corresponding to keys from Xarray objects Args: xr_obj: Xarray DataArray or Dataset str: string corresponding to an attribute key @@ -21,10 +21,31 @@ def get_attributes(xr_obj, str): Xarray attribute value as string """ if str in xr_obj.attrs: - return xr_obj.attrs[str] + return xr_obj.attrs[str] + elif str in xr_obj.dims: + return str #special case because DataArray and Dataset dims are not the same types else: raise Exception('Attribute "{0}" not found in "{1}"'.format(str, xr_obj.name)) +def default_attrs(xr_obj): + """ + Builds a dictionary of default Xarray object attributes to use as plot labels, + using similar behaviour to Xarray.DataArray.plot() + + Args: + xr_obj: Xarray object (DataArray or Dataset) + Returns: + dict of key-value pairs of the format (plot_element:attribute_name) + """ + default = {} + default['title'] = 'long_name' + default['xlabel'] = 'time' + default['ylabel'] = 'standard_name' + default['yunits'] = 'units' + + return default + + def set_plot_attrs(attr_dict, xr_obj, ax): """ Sets plot elements according to DataArray attributes. Uses get_attributes() @@ -34,13 +55,15 @@ def set_plot_attrs(attr_dict, xr_obj, ax): ax: matplotlib axis Returns: matplotlib axis + Todo: include lat,lon coordinates in title, add warning if input not in list (e.g. y_label) """ if 'title' in attr_dict: ax.set_title(get_attributes(xr_obj, attr_dict['title']), wrap=True) if 'xlabel' in attr_dict: ax.set_xlabel(get_attributes(xr_obj, attr_dict['xlabel'])) #rotation? if 'ylabel' in attr_dict: - ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel'])) + ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel'])+ ' [' + + get_attributes(xr_obj, attr_dict['yunits'])+ ']') return ax @@ -76,17 +99,16 @@ def sort_lines(array_dict): # lnx = np.arange(1,10,1) # ln1 = lnx + 3 + 1*np.random.rand() # ln2 = lnx + 10 + 3*np.random.rand() -# ln3 = lnx + 6 + 2*np.random.rand() +# #ln3 = lnx + 6 + 2*np.random.rand() # # fig, ax = plt.subplots(figsize = (4,3)) -# ax.plot(lnx, ln1) -# ax.plot(lnx, ln2) -# ax.plot(lnx, ln3) +# line_1 = ax.plot(lnx, ln1) +# line_2 = ax.plot(lnx, ln2) +# #line_3 = ax.plot(lnx, ln3) # # ax.fill_between(lnx, ln1,ln2, alpha = 0.2, color = 'red') -# ax.fill_between(lnx, ln1,ln3, alpha = 0.2, color = 'blue') +# #ax.fill_between(lnx, ln1,ln3, alpha = 0.2, color = 'blue') # plt.show() - From 7b332e49a8cd438d6c59ec2c3905656ab6454c34 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Thu, 19 Jan 2023 17:13:16 -0500 Subject: [PATCH 08/25] timeseries todo list updated --- spirograph/matplotlib/get_example_data.py | 13 ++++-- spirograph/matplotlib/timeseries.py | 52 +++++++++++++++++------ spirograph/matplotlib/util_fcts.py | 8 +++- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index 65183d34..c0aa1ce0 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -1,16 +1,21 @@ import xarray as xr +from xclim import ensembles # ensemble percentiles Xclim - dataset where each variable represents a percentile with lat, lon, time dims url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' -ds_pct_open = xr.open_dataset(url_1, decode_cf= False) -ds_pct = ds_pct_open.isel(lon = 500, lat = 250)[['tx_max_p50', 'tx_max_p10', 'tx_max_p90']] +ds_pct_open = xr.open_dataset(url_1, decode_timedelta=False) +ds_pct = ds_pct_open.isel(lon=500, lat=250)[['tx_max_p50', 'tx_max_p10', 'tx_max_p90']] da_pct = ds_pct['tx_max_p50'] + + + # DataArray simple, tasmax across time at a given point in space url_2 = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/datasets/simulations/bias_adjusted/cmip6/pcic/CanDCS-U6/day_BCCAQv2+ANUSPLIN300_UKESM1-0-LL_historical+ssp585_r1i1p1f2_gn_1950-2100.ncml' -ds_var_open = xr.open_dataset(url_2, decode_cf= False) +ds_var_open = xr.open_dataset(url_2, decode_timedelta= False) ds_var = ds_var_open.isel(lon = 500, lat = 250) -da_var = ds_tasmax_open.isel(lon = 500, lat = 250)['tasmax'] + + diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 7601dbbf..13c71356 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -2,26 +2,33 @@ import numpy as np import xarray as xr import matplotlib.pyplot as plt +import pandas as pd # To add -# language -# smoothing -# logo -# transformer dates de days since... a dates? -# use_attrs = default au lieu de None?? +# language +# logo +# transformer dates de days since... a dates? +# erreur ensemble+DA +# xlim +# assigning kwargs to different lines: list?? +# input = dict(nom:ds ou nom:da), avec noms qui deviennent legend +# detect if ensemble (mean, max, min _pNN), detect if in coords + ## fonction qui label chaque entree comme {global_label:, type: ds/da, ens = da/ds/var_ens/dim_ens}, +# assumer que 'time' est la dimension, et fct qui regarde +# lorsque plusieurs datasets, prendre +# lorsque dataset n Date: Mon, 23 Jan 2023 18:20:12 -0500 Subject: [PATCH 09/25] restructuration de la fonction pour accepter un dictionnaire de Datasets --- spirograph/matplotlib/get_example_data.py | 8 +- spirograph/matplotlib/timeseries.py | 100 +++++++++++++--------- spirograph/matplotlib/util_fcts.py | 58 +++++++------ 3 files changed, 95 insertions(+), 71 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index c0aa1ce0..487f09f7 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -1,5 +1,6 @@ import xarray as xr -from xclim import ensembles +import numpy as np + # ensemble percentiles Xclim - dataset where each variable represents a percentile with lat, lon, time dims url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' @@ -8,7 +9,10 @@ da_pct = ds_pct['tx_max_p50'] - +data = np.random.rand(4,3) +time = [1,2,3,4] +pct = [15,50,95] +datest = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) # DataArray simple, tasmax across time at a given point in space url_2 = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/datasets/simulations/bias_adjusted/cmip6/pcic/CanDCS-U6/day_BCCAQv2+ANUSPLIN300_UKESM1-0-LL_historical+ssp585_r1i1p1f2_gn_1950-2100.ncml' diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 13c71356..f4bc70e3 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -14,11 +14,10 @@ # xlim # assigning kwargs to different lines: list?? # input = dict(nom:ds ou nom:da), avec noms qui deviennent legend -# detect if ensemble (mean, max, min _pNN), detect if in coords - ## fonction qui label chaque entree comme {global_label:, type: ds/da, ens = da/ds/var_ens/dim_ens}, # assumer que 'time' est la dimension, et fct qui regarde -# lorsque plusieurs datasets, prendre +# lorsque plusieurs datasets, prendre le 1er # lorsque dataset n": - for k, v in data.data_vars.items(): - array_dict[k] = v - if ensemble is True: - sorted_lines = sort_lines(array_dict) - else: - array_dict[data.name] = data - - #plot - if ensemble is True: - - line_1 = ax.plot(array_dict[sorted_lines['middle']][array_dict[sorted_lines['middle']].dims[0]], - array_dict[sorted_lines['middle']].values, **kwargs['line_kw']) - - ax.fill_between(array_dict[sorted_lines['lower']][array_dict[sorted_lines['lower']].dims[0]], - array_dict[sorted_lines['lower']].values, - array_dict[sorted_lines['upper']].values, - color = line_1[0].get_color(), - edgecolor = 'white', alpha = 0.2) - else: - for name, arr in array_dict.items(): - ax.plot(arr[arr.dims[0]], arr.values, label = name, **kwargs['line_kw']) - - #add/modify plot elements - plot_attrs = default_attrs() - if use_attrs: - for k, v in use_attrs.items(): - plot_attrs[k] = v - set_plot_attrs(plot_attrs, data, ax) + #add/modify plot elements according to the first entry + set_plot_attrs(plot_attrs, list(data.values())[0], ax) ax.legend() @@ -95,6 +112,7 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k line_ts(da_pct, line_kw = {'color': 'red'}) line_ts(ds_pct, use_attrs= {'title': 'ccdp_name'}) line_ts(ds_pct, ensemble=True, line_kw={'color': 'red'}) +line_ts({"BOURGEONNOISERIES":ds_pct}) mod_da_pct = da_pct + da_pct*0.05 diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index c786a29b..cfddebee 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -1,19 +1,38 @@ -import numpy as np -import xarray as xr -import matplotlib.pyplot as plt -import pandas as pd +import pandas as pd +import re +def get_array_categ(array): + """Returns an array category + PCT_VAR_ENS: ensemble of percentiles stored as variables + PCT_DIM_ENS: ensemble of percentiles stored as dimension coordinates + STATS_VAR_ENS: ensemble of statistics (min, mean, max) stored as variables + NON_ENS_DS: dataset of individual lines, not an ensemble + DA: DataArray + Args: + data_dict: Xarray Dataset or DataArray + Returns + str + """ + + if str(type(array)) == "": + if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >=2: + cat = "PCT_VAR_ENS" + elif pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: + cat = "PCT_DIM_ENS" + elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: + cat = "STATS_VAR_ENS" + else: + cat = "NON_ENS_DS" + + elif str(type(array)) == "": + cat = "DA" + else: + raise TypeError('Array is not an Xarray Dataset or DataArray') + return cat -def empty_dict(kwargs): - """Returns empty dictionaries - """ - for k, v in kwargs.items(): - if not v: - kwargs[k] = {} - return kwargs def get_attributes(xr_obj, str): """ @@ -31,23 +50,6 @@ def get_attributes(xr_obj, str): else: raise Exception('Attribute "{0}" not found in "{1}"'.format(str, xr_obj.name)) -def default_attrs(): - """ - Builds a dictionary of default Xarray object attributes to use as plot labels, - using similar behaviour to Xarray.DataArray.plot() - - Args: - xr_obj: Xarray object (DataArray or Dataset) - Returns: - dict of key-value pairs of the format (plot_element:attribute_name) - """ - default = {} - default['title'] = 'long_name' - default['xlabel'] = 'time' - default['ylabel'] = 'standard_name' - default['yunits'] = 'units' - - return default def set_plot_attrs(attr_dict, xr_obj, ax): From d157f8215696849e73d31cc78e0d0db454716fac Mon Sep 17 00:00:00 2001 From: abeaupre Date: Tue, 24 Jan 2023 15:29:29 -0500 Subject: [PATCH 10/25] functionality for var, dim ensembles, regular DS and regular DA --- spirograph/matplotlib/get_example_data.py | 4 +- spirograph/matplotlib/timeseries.py | 83 +++++++++++++---------- spirograph/matplotlib/util_fcts.py | 5 +- 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index 487f09f7..a2e24400 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -9,8 +9,8 @@ da_pct = ds_pct['tx_max_p50'] -data = np.random.rand(4,3) -time = [1,2,3,4] +data = np.random.rand(4,3)*25 + 300 +time = pd.date_range(start ='1960-01-01', end = '2020-01-01', periods = 4) pct = [15,50,95] datest = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index f4bc70e3..f0a2e24e 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -5,21 +5,13 @@ import pandas as pd - # To add # language # logo -# transformer dates de days since... a dates? -# erreur ensemble+DA # xlim -# assigning kwargs to different lines: list?? -# input = dict(nom:ds ou nom:da), avec noms qui deviennent legend # assumer que 'time' est la dimension, et fct qui regarde -# lorsque plusieurs datasets, prendre le 1er -# lorsque dataset n Date: Fri, 27 Jan 2023 14:27:48 -0500 Subject: [PATCH 11/25] minor improvements, scripts to get data --- spirograph/matplotlib/get_example_data.py | 109 +++++++++++++++++++--- spirograph/matplotlib/timeseries.py | 57 ++++++----- spirograph/matplotlib/timeseries_test.py | 22 +++++ spirograph/matplotlib/util_fcts.py | 70 +++++++------- 4 files changed, 179 insertions(+), 79 deletions(-) create mode 100644 spirograph/matplotlib/timeseries_test.py diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index a2e24400..e8cb3410 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -1,25 +1,108 @@ import xarray as xr import numpy as np +import pandas as pd +import glob +from xclim import ensembles +import re +# create NetCDFs -# ensemble percentiles Xclim - dataset where each variable represents a percentile with lat, lon, time dims -url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' -ds_pct_open = xr.open_dataset(url_1, decode_timedelta=False) -ds_pct = ds_pct_open.isel(lon=500, lat=250)[['tx_max_p50', 'tx_max_p10', 'tx_max_p90']] -da_pct = ds_pct['tx_max_p50'] +## rcp4.5, 2015, 3 models +ens2015_rcp45 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp45_*_2015.nc') +tasmax_rcp45_2015_1 = ensembles.create_ensemble(ens2015_rcp45[3:6]) +tasmax_rcp45_2015_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp45_2015_1) +tasmax_rcp45_2015_1_perc = ensembles.ensemble_percentiles(tasmax_rcp45_2015_1, values=[15, 50, 85], split=False) -data = np.random.rand(4,3)*25 + 300 -time = pd.date_range(start ='1960-01-01', end = '2020-01-01', periods = 4) -pct = [15,50,95] -datest = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) +tasmax_rcp45_2015_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_1_stats.nc') +tasmax_rcp45_2015_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_1_perc.nc') + +## rcp4.5, 2015, 3 other models +ens2015_rcp45 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp45_*_2015.nc') +tasmax_rcp45_2015_2 = ensembles.create_ensemble(ens2015_rcp45[0:3]) + +tasmax_rcp45_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp45_2015_2) +tasmax_rcp45_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp45_2015_2, values=[15, 50, 85], split=False) + +tasmax_rcp45_2015_2_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_2_stats.nc') +tasmax_rcp45_2015_2_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_2_perc.nc') + +## rcp8.5, 2015, 3 other models +ens2015_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2015.nc') +tasmax_rcp85_2015_1 = ensembles.create_ensemble(ens2015_rcp85[3:6]) + +tasmax_rcp85_2015_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_1) +tasmax_rcp85_2015_1_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_1, values=[15, 50, 85], split=False) + +tasmax_rcp85_2015_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_1_stats.nc') +tasmax_rcp85_2015_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_1_perc.nc') + + +## rcp8.5, 2015, 3 other models +ens2015_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2015.nc') +tasmax_rcp85_2015_2 = ensembles.create_ensemble(ens2015_rcp85[0:3]) + +tasmax_rcp85_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_2) +tasmax_rcp85_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_2, values=[15, 50, 85], split=False) + +tasmax_rcp85_2015_2_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_2_stats.nc') +tasmax_rcp85_2015_2_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_2_perc.nc') + +## rcp4.5, 2012, 3 models +ens2012_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2012.nc') +tasmax_rcp85_2012_1 = ensembles.create_ensemble(ens2012_rcp85[5:8]) -# DataArray simple, tasmax across time at a given point in space -url_2 = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/datasets/simulations/bias_adjusted/cmip6/pcic/CanDCS-U6/day_BCCAQv2+ANUSPLIN300_UKESM1-0-LL_historical+ssp585_r1i1p1f2_gn_1950-2100.ncml' -ds_var_open = xr.open_dataset(url_2, decode_timedelta= False) -ds_var = ds_var_open.isel(lon = 500, lat = 250) +tasmax_rcp85_2012_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2012_1) +tasmax_rcp85_2012_1_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2012_1, values=[15, 50, 85], split=False) +tasmax_rcp85_2012_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2012_1_stats.nc') +tasmax_rcp85_2012_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2012_1_perc.nc') +# import and process +def output_ds(paths): + + target_lat = 45.5 + target_lon = -73.6 + + datasets = {} + + for path in paths: + if re.search("_stats", path): + open_ds = xr.open_dataset(path, decode_timedelta=False) + var_ds = open_ds[['tasmax_mean', 'tasmax_min', 'tasmax_max']] + elif re.search("_perc", path): + open_ds = xr.open_dataset(path, decode_timedelta=False) + var_ds = open_ds.drop_dims('ts')['tasmax'] + else: + print(path, ' not _stats or _perc') + continue + + loc_ds = var_ds.sel(lat = target_lat, lon = target_lon, method = 'nearest').\ + convert_calendar('standard') + datasets[path.split(sep = '/')[-1]] = loc_ds + return datasets + +paths = glob.glob('/exec/abeaupre/Projects/spirograph/test_data/tasmax*.nc') + +datasets = output_ds(paths) + +# Other datasets +## ensemble percentiles (pct in variables) +url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' +ds_pct_open = xr.open_dataset(url_1, decode_timedelta=False) + +ds_pct_1 = ds_pct_open.isel(lon=500, lat=250)[['tx_max_p50', 'tx_max_p10', 'tx_max_p90']] +da_pct_1 = ds_pct_1['tx_max_p50'] + +## randomly-generated ensemble percentiles (pct in dims). No attributes +data = np.random.rand(4,3)*25 + 300 +time = pd.date_range(start ='1960-01-01', end = '2020-01-01', periods = 4) +pct = [15,50,95] + +da_pct_rand = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) +attr_list = ['long_name','time','standard_name','units'] +for a in attr_list: + da_pct_rand.attrs[a] = 'default diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index f0a2e24e..367db889 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -9,12 +9,23 @@ # language # logo # xlim -# assumer que 'time' est la dimension, et fct qui regarde +# fct qui s'assure que 'time' est une dimension # variables superflues? -# FIX when PCT_DIM_ENS, label for each line +# FIX when PCT_DIM_ENS, label for each line?? +# FIX label used twice +# ensemble percentiles is a dataarray? it's a dataset +# show lat,lon +# show percentiles? +# option for legend placement +#CHECK: if data not a dict, line_kw not a dict of dicts +#Exceptions, no data, data is all nans +# cftime conversion? +# xlabel not showing up +# ylabel at end of line rather than legend -def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_kw=None): + +def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): """ Plots time series from 1D dataframe or dataset Args: @@ -31,7 +42,8 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k # if only one data input, insert in dict non_dict_data = False if type(data) != dict: - data = {'data': data} + data = {'data_1': data} + line_kw = {'data_1': empty_dict(line_kw)} non_dict_data = True # set default kwargs and add/replace with user inputs, if provided @@ -40,14 +52,16 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k 'ylabel': 'standard_name', 'yunits': 'units'} plot_sub_kw = {} - plot_line_kw = { name: {} for name in data.keys() } + if non_dict_data is True: + plot_line_kw = {} + else: + plot_line_kw = {name: {} for name in data.keys()} for user_dict, attr_dict in zip([use_attrs, sub_kw, line_kw], [plot_attrs, plot_sub_kw, plot_line_kw]): if user_dict: - for k,v in user_dict.items(): + for k, v in user_dict.items(): attr_dict[k] = v - kwargs = {'sub_kw': plot_sub_kw, 'line_kw': plot_line_kw} # set fig, ax if not provided @@ -62,13 +76,13 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k for name, arr in data.items(): - if array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS']: + if array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS_DA']: # extract each line from the datasets array_data = {} - if array_categ[name] == 'PCT_DIM_ENS': + if array_categ[name] == 'PCT_DIM_ENS_DA': for pct in arr.percentiles: - array_data[pct] = arr.sel(percentiles=int(pct)) + array_data[str(int(pct))] = arr.sel(percentiles=int(pct)) else: for k, v in arr.data_vars.items(): array_data[k] = v @@ -91,14 +105,14 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k for k, sub_arr in arr.data_vars.items(): sub_name = name + "_" + sub_arr.name # creates plot label - ax.plot(sub_arr['time'], sub_arr.values, **kwargs['line_kw'][name], label = sub_name) + ax.plot(sub_arr['time'], sub_arr.values, **kwargs['line_kw'][name], label=sub_name) else: # should be DataArray - ax.plot(arr['time'], arr.values, **kwargs['line_kw'][name], label = name) + ax.plot(arr['time'], arr.values, **kwargs['line_kw'][name], label=name) - #add/modify plot elements according to the first entry + #add/modify plot elements according to the first entry. set_plot_attrs(plot_attrs, list(data.values())[0], ax) if non_dict_data is False: @@ -107,23 +121,6 @@ def line_ts(data, ensemble = False, ax=None, use_attrs=None, sub_kw=None, line_k return ax -# test - -## simple DataArray, labeled or unlabeled -line_ts(da_pct, line_kw = {'color': 'red'}) -line_ts({'My data': da_pct}) - -## simple Dataset ensemble (variables) -line_ts(ds_pct, use_attrs= {'title': 'ccdp_name'}) -line_ts({'My other data':ds_pct}, use_attrs= {'title': 'ccdp_name'}) - -## simple Dataset ensemble (pct) -line_ts(datest) -line_ts({'My random data': datest}) - -## all together -line_ts({'DataArray':da_pct, 'Var Ensemble': ds_pct, 'other': datest}, - line_kw = {'DataArray':{'color': 'purple'}, 'Var Ensemble': {'color': 'brown'}}) diff --git a/spirograph/matplotlib/timeseries_test.py b/spirograph/matplotlib/timeseries_test.py new file mode 100644 index 00000000..9c231473 --- /dev/null +++ b/spirograph/matplotlib/timeseries_test.py @@ -0,0 +1,22 @@ + +import matplotlib as mpl +mpl.use("Qt5Agg") + + + +# test + +## simple DataArray, labeled or unlabeled +line_ts(da_pct_1, line_kw={'color': 'red'}) +line_ts({'My data': da_pct_1}, line_kw={'My data':{'color': 'red'}}) + +## simple Dataset ensemble (variables) +line_ts(ds_stats_2015) +line_ts({'2015 daily rcp4.5 stats': ds_stats_2015}, line_kw = {'2015 daily rcp4.5 stats':{'color':'purple'}}) + +## simple Dataset ensemble (dims) +line_ts({'2012 daily rcp4.5 percentiles': ds_perc_2012}) + +## all together +line_ts({'DataArray': da_pct, 'Var Ensemble': ds_pct, 'other': datest}, + line_kw={'DataArray': {'color': 'blue'}, 'Var Ensemble': {'color': 'red'}}) diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index 062964dc..8ade9c1c 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -2,13 +2,16 @@ import pandas as pd import re - +def empty_dict(param): + if param is None: + param = {} + return param def get_array_categ(array): """Returns an array category PCT_VAR_ENS: ensemble of percentiles stored as variables - PCT_DIM_ENS: ensemble of percentiles stored as dimension coordinates + PCT_DIM_ENS_DA: ensemble of percentiles stored as dimension coordinates, DataArray STATS_VAR_ENS: ensemble of statistics (min, mean, max) stored as variables NON_ENS_DS: dataset of individual lines, not an ensemble DA: DataArray @@ -21,44 +24,54 @@ def get_array_categ(array): if str(type(array)) == "": if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >=2: cat = "PCT_VAR_ENS" - elif pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: - cat = "PCT_DIM_ENS" elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: cat = "STATS_VAR_ENS" + elif pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: + cat = "PCT_DIM_ENS_DS" ## no support for now else: cat = "NON_ENS_DS" elif str(type(array)) == "": - cat = "DA" + if pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: + cat = "PCT_DIM_ENS_DA" + else: + cat = "DA" else: raise TypeError('Array is not an Xarray Dataset or DataArray') + print('cat: ', cat) return cat -def get_attributes(xr_obj, str): +def get_attributes(xr_obj, strg): """ - Fetches attributes or dims corresponding to keys from Xarray objects + Fetches attributes or dims corresponding to keys from Xarray objects. Looks in + Dataset attributes first, and then looks in DataArray. Args: xr_obj: Xarray DataArray or Dataset str: string corresponding to an attribute key Returns: Xarray attribute value as string """ - if str in xr_obj.attrs: - return xr_obj.attrs[str] - elif str in xr_obj.dims: - return str # special case because DataArray and Dataset dims are not the same types + if strg in xr_obj.attrs: + return xr_obj.attrs[strg] + elif str(type(xr_obj)) == "": + if strg in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable + return xr_obj[list(xr_obj.data_vars)[0]].attrs[strg] + elif strg in xr_obj.dims: + return strg # special case for 'time' because DataArray and Dataset dims are not the same types else: - raise Exception('Attribute "{0}" not found in "{1}"'.format(str, xr_obj.name)) + print('Attribute "{0}" not found in attributes'.format(strg)) + return '' def set_plot_attrs(attr_dict, xr_obj, ax): """ - Sets plot elements according to DataArray attributes. Uses get_attributes() + Sets plot elements according to Dataset or DataArray attributes. Uses get_attributes() + to check for and return the string. Args: use_attrs (dict): dict containing specified attribute keys - xr_obj: Xarray DataArray + xr_obj: Xarray DataArray. ax: matplotlib axis Returns: matplotlib axis @@ -67,10 +80,13 @@ def set_plot_attrs(attr_dict, xr_obj, ax): if 'title' in attr_dict: ax.set_title(get_attributes(xr_obj, attr_dict['title']), wrap=True) if 'xlabel' in attr_dict: - ax.set_xlabel(get_attributes(xr_obj, attr_dict['xlabel'])) #rotation? + ax.set_xlabel(get_attributes(xr_obj, attr_dict['xlabel'])) # rotation? if 'ylabel' in attr_dict: - ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel'])+ ' [' + - get_attributes(xr_obj, attr_dict['yunits'])+ ']') + if 'units' in attr_dict and len(attr_dict['units']) >= 1: # second condition avoids '[]' as label + ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel']) + ' [' + + get_attributes(xr_obj, attr_dict['yunits']) + ']') + else: + ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel'])) return ax @@ -86,7 +102,7 @@ def sort_lines(array_dict): ref_values = {} sorted_lines = {} for name, arr in array_dict.items(): - ref_values[name] = int(arr[int(len(arr)/2)]) + ref_values[name] = float(arr[int(len(arr)/2)]) # why the first int?? sorted_series = pd.Series(ref_values).sort_values() sorted_lines['lower'] = sorted_series.idxmin() sorted_lines['upper'] = sorted_series.idxmax() @@ -101,21 +117,3 @@ def sort_lines(array_dict): - - -# lnx = np.arange(1,10,1) -# ln1 = lnx + 3 + 1*np.random.rand() -# ln2 = lnx + 10 + 3*np.random.rand() -# #ln3 = lnx + 6 + 2*np.random.rand() -# -# fig, ax = plt.subplots(figsize = (4,3)) -# line_1 = ax.plot(lnx, ln1) -# line_2 = ax.plot(lnx, ln2) -# #line_3 = ax.plot(lnx, ln3) -# -# ax.fill_between(lnx, ln1,ln2, alpha = 0.2, color = 'red') -# #ax.fill_between(lnx, ln1,ln3, alpha = 0.2, color = 'blue') -# plt.show() - - - From d00c469ddbcaa9f196e6fe4d4c81c9729f404412 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Fri, 27 Jan 2023 16:12:01 -0500 Subject: [PATCH 12/25] first functionality tests --- spirograph/matplotlib/get_example_data.py | 6 ++-- spirograph/matplotlib/timeseries.py | 5 +--- spirograph/matplotlib/timeseries_test.py | 36 ++++++++++++++++------- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index e8cb3410..3c59076a 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -69,6 +69,7 @@ def output_ds(paths): datasets = {} + for path in paths: if re.search("_stats", path): open_ds = xr.open_dataset(path, decode_timedelta=False) @@ -82,7 +83,8 @@ def output_ds(paths): loc_ds = var_ds.sel(lat = target_lat, lon = target_lon, method = 'nearest').\ convert_calendar('standard') - datasets[path.split(sep = '/')[-1]] = loc_ds + datasets[path.split(sep='/')[-1].split(sep='.')[0]] = loc_ds + return datasets paths = glob.glob('/exec/abeaupre/Projects/spirograph/test_data/tasmax*.nc') @@ -105,4 +107,4 @@ def output_ds(paths): da_pct_rand = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) attr_list = ['long_name','time','standard_name','units'] for a in attr_list: - da_pct_rand.attrs[a] = 'default + da_pct_rand.attrs[a] = 'default' diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 367db889..61657f41 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -10,18 +10,15 @@ # logo # xlim # fct qui s'assure que 'time' est une dimension -# variables superflues? -# FIX when PCT_DIM_ENS, label for each line?? # FIX label used twice -# ensemble percentiles is a dataarray? it's a dataset # show lat,lon # show percentiles? # option for legend placement #CHECK: if data not a dict, line_kw not a dict of dicts #Exceptions, no data, data is all nans # cftime conversion? -# xlabel not showing up # ylabel at end of line rather than legend +#special term in use_attrs to use dict key diff --git a/spirograph/matplotlib/timeseries_test.py b/spirograph/matplotlib/timeseries_test.py index 9c231473..fb53b5c4 100644 --- a/spirograph/matplotlib/timeseries_test.py +++ b/spirograph/matplotlib/timeseries_test.py @@ -1,22 +1,38 @@ import matplotlib as mpl +import matplotlib.pyplot as plt mpl.use("Qt5Agg") # test -## simple DataArray, labeled or unlabeled +## 1 . Basic plot functionality + +## simple DataArray, unlabeled line_ts(da_pct_1, line_kw={'color': 'red'}) -line_ts({'My data': da_pct_1}, line_kw={'My data':{'color': 'red'}}) -## simple Dataset ensemble (variables) -line_ts(ds_stats_2015) -line_ts({'2015 daily rcp4.5 stats': ds_stats_2015}, line_kw = {'2015 daily rcp4.5 stats':{'color':'purple'}}) +## simple DataArray, labeled +line_ts({'My data': da_pct_1}, line_kw={'My data': {'color': 'red'}}) -## simple Dataset ensemble (dims) -line_ts({'2012 daily rcp4.5 percentiles': ds_perc_2012}) +## idem, with no attributes +line_ts({'Random data': da_pct_rand}) -## all together -line_ts({'DataArray': da_pct, 'Var Ensemble': ds_pct, 'other': datest}, - line_kw={'DataArray': {'color': 'blue'}, 'Var Ensemble': {'color': 'red'}}) +## simple Dataset ensemble (variables) +line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}) +line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, + line_kw={'rcp45_2015_1': {'color': 'purple'}}) + +## simple Dataset ensemble (dims), title override +my_ax = line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}, + line_kw={'rcp45_2015_1': {'color': '#daa520'}}) +my_ax.set_title('The percentiles are in dimensions') + +## one DataArray, one pct Dataset, one stats Dataset +line_ts({'DataArray': datasets['tasmax_rcp45_2015_1_stats']['tasmax_mean'], + 'Dataset_vars': datasets['tasmax_rcp45_2015_2_stats'], + 'Dataset_dims': datasets['tasmax_rcp85_2015_1_perc']}, + line_kw={'DataArray': {'color': '#000080'}, + 'Dataset_vars': {'color': '#ffa500'}, + 'Dataset_dims': {'color':'#468499'} + }) From 7b7eeffdab28d86b880acfa8bfd6ed5bb7474a32 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Tue, 31 Jan 2023 18:24:10 -0500 Subject: [PATCH 13/25] legend inclues patches, many other upgrades --- spirograph/matplotlib/get_example_data.py | 17 ++-- spirograph/matplotlib/timeseries.py | 99 ++++++++++++++--------- spirograph/matplotlib/timeseries_test.py | 9 ++- spirograph/matplotlib/util_fcts.py | 91 +++++++++++++++------ 4 files changed, 149 insertions(+), 67 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index 3c59076a..96987491 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -66,8 +66,9 @@ def output_ds(paths): target_lat = 45.5 target_lon = -73.6 + time_slice = slice(150,200) - datasets = {} + dsets = {} for path in paths: @@ -81,16 +82,22 @@ def output_ds(paths): print(path, ' not _stats or _perc') continue - loc_ds = var_ds.sel(lat = target_lat, lon = target_lon, method = 'nearest').\ - convert_calendar('standard') - datasets[path.split(sep='/')[-1].split(sep='.')[0]] = loc_ds + loc_ds = var_ds.sel(lat=target_lat, lon=target_lon, method='nearest') + #.convert_calendar('standard') + if time_slice: + loc_ds = loc_ds.isel(time=time_slice) + dsets[path.split(sep='/')[-1].split(sep='.')[0]] = loc_ds + + return dsets - return datasets paths = glob.glob('/exec/abeaupre/Projects/spirograph/test_data/tasmax*.nc') datasets = output_ds(paths) + +#datasets['tasmax_rcp45_2015_1_stats'] + # Other datasets ## ensemble percentiles (pct in variables) url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 61657f41..47b16989 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -2,59 +2,68 @@ import numpy as np import xarray as xr import matplotlib.pyplot as plt +from matplotlib.patches import Patch +import warnings import pandas as pd - # To add -# language -# logo -# xlim -# fct qui s'assure que 'time' est une dimension +# language +# logo # FIX label used twice -# show lat,lon # show percentiles? -# option for legend placement -#CHECK: if data not a dict, line_kw not a dict of dicts -#Exceptions, no data, data is all nans -# cftime conversion? # ylabel at end of line rather than legend #special term in use_attrs to use dict key +#change fill_between() edgecolor to the background color - - -def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): +def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines'): """ Plots time series from 1D dataframe or dataset Args: - ax: user-specified matplotlib axis data: dictionary of labeled Xarray DataArrays or Datasets + ax: user-specified matplotlib axis use_attrs: dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description') sub_kw: matplotlib subplot kwargs - line_kw : matplotlib or xarray line kwargs + line_kw: matplotlib or xarray line kwargs + legend: 'full' (lines and shading), 'lines' (lines only), 'none' (no legend) Returns: matplotlib axis """ # if only one data input, insert in dict non_dict_data = False + if type(data) != dict: data = {'data_1': data} line_kw = {'data_1': empty_dict(line_kw)} non_dict_data = True - # set default kwargs and add/replace with user inputs, if provided + # basic checks + ## type + for name, arr in data.items(): + if not isinstance(arr, (xr.Dataset, xr.DataArray)): + raise TypeError('data must contain Xarray-type objects') + + ## 'time' dimension and calendar format + data = check_timeindex(data) + + + # set default kwargs plot_attrs = {'title': 'long_name', 'xlabel': 'time', 'ylabel': 'standard_name', 'yunits': 'units'} + plot_sub_kw = {} + if non_dict_data is True: plot_line_kw = {} else: plot_line_kw = {name: {} for name in data.keys()} - for user_dict, attr_dict in zip([use_attrs, sub_kw, line_kw], [plot_attrs, plot_sub_kw, plot_line_kw]): + # add/replace default kwargs with user inputs + for user_dict, attr_dict in zip([use_attrs, sub_kw, line_kw], + [plot_attrs, plot_sub_kw, plot_line_kw]): if user_dict: for k, v in user_dict.items(): attr_dict[k] = v @@ -65,17 +74,18 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): if not ax: fig, ax = plt.subplots(**kwargs['sub_kw']) - - # build dictionary of array 'categories', which determine how to plot (see get_array_categ fct) + # build dictionary of array 'categories', which determine how to plot data array_categ = {name: get_array_categ(array) for name, array in data.items()} - # get data to plot + # get data and plot + lines_dict = {} for name, arr in data.items(): + # ensembles if array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS_DA']: - # extract each line from the datasets + # extract each array from the datasets array_data = {} if array_categ[name] == 'PCT_DIM_ENS_DA': for pct in arr.percentiles: @@ -87,34 +97,51 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None): # create a dictionary labeling the middle, upper and lower line sorted_lines = sort_lines(array_data) - # plot the ensemble - line_1 = ax.plot(array_data[sorted_lines['middle']]['time'], + lines_dict[name] = ax.plot(array_data[sorted_lines['middle']]['time'], array_data[sorted_lines['middle']].values, **kwargs['line_kw'][name], label=name) ax.fill_between(array_data[sorted_lines['lower']]['time'], array_data[sorted_lines['lower']].values, array_data[sorted_lines['upper']].values, - color=line_1[0].get_color(), + color=lines_dict[name][0].get_color(), edgecolor='white', alpha=0.2) - elif array_categ[name] in ['NON_ENS_DS']: - - for k, sub_arr in arr.data_vars.items(): - sub_name = name + "_" + sub_arr.name # creates plot label - ax.plot(sub_arr['time'], sub_arr.values, **kwargs['line_kw'][name], label=sub_name) - - - else: # should be DataArray + if legend == 'full': + patch = Patch(facecolor=lines_dict[name][0].get_color(), + edgecolor='white', alpha=0.2, + label="{} - {}".format(sorted_lines['lower'], + sorted_lines['upper'])) - ax.plot(arr['time'], arr.values, **kwargs['line_kw'][name], label=name) - #add/modify plot elements according to the first entry. + # non-ensemble Datasets + elif array_categ[name] in ['NON_ENS_DS']: + for k, sub_arr in arr.data_vars.items(): + sub_name = name + "_" + sub_arr.name # creates plot label + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, + **kwargs['line_kw'][name], label=sub_name + ) + + # non-ensemble DataArrays + else: + lines_dict[name] = ax.plot(arr['time'], arr.values, + **kwargs['line_kw'][name], label=name + ) + + # add/modify plot elements according to the first entry. set_plot_attrs(plot_attrs, list(data.values())[0], ax) - if non_dict_data is False: - ax.legend() + # other plot elements + + if non_dict_data is False and legend is not None: + if legend == 'full': + handles = [v[0] for v in list(lines_dict.values())] # line objects are tuples(?) + handles.append(patch) + ax.legend(handles=handles) + else: + ax.legend() + ax.margins(x=0, y=0.05) return ax diff --git a/spirograph/matplotlib/timeseries_test.py b/spirograph/matplotlib/timeseries_test.py index fb53b5c4..c22d289c 100644 --- a/spirograph/matplotlib/timeseries_test.py +++ b/spirograph/matplotlib/timeseries_test.py @@ -19,7 +19,7 @@ line_ts({'Random data': da_pct_rand}) ## simple Dataset ensemble (variables) -line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}) +line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full') line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, line_kw={'rcp45_2015_1': {'color': 'purple'}}) @@ -34,5 +34,10 @@ 'Dataset_dims': datasets['tasmax_rcp85_2015_1_perc']}, line_kw={'DataArray': {'color': '#000080'}, 'Dataset_vars': {'color': '#ffa500'}, - 'Dataset_dims': {'color':'#468499'} + 'Dataset_dims': {'color': '#468499'} }) + +# test with non-ensemble DS + + +ll = line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full') diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index 8ade9c1c..b3fc533e 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -1,12 +1,25 @@ import pandas as pd import re +import warnings +import xarray as xr def empty_dict(param): if param is None: param = {} return param +def check_timeindex(xr_dict): + """ checks if the time index of Xarray objects in a dict is CFtime + and converts to pd.DatetimeIndex if true""" + for name, xr_obj in xr_dict.items(): + if 'time' in xr_obj.dims: + if isinstance(xr_obj.get_index('time'), xr.CFTimeIndex): + conv_obj = xr_obj.convert_calendar('standard', use_cftime=None) + xr_dict[name] = conv_obj + else: + raise ValueError('"time" dimension not found in {}'.format(xr_obj)) + return xr_dict def get_array_categ(array): """Returns an array category @@ -20,8 +33,7 @@ def get_array_categ(array): Returns str """ - - if str(type(array)) == "": + if isinstance(array, xr.Dataset): if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >=2: cat = "PCT_VAR_ENS" elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: @@ -31,18 +43,18 @@ def get_array_categ(array): else: cat = "NON_ENS_DS" - elif str(type(array)) == "": + elif isinstance(array, xr.DataArray): if pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: cat = "PCT_DIM_ENS_DA" else: cat = "DA" else: raise TypeError('Array is not an Xarray Dataset or DataArray') - print('cat: ', cat) + return cat -def get_attributes(xr_obj, strg): +def get_attributes(strg, xr_obj): """ Fetches attributes or dims corresponding to keys from Xarray objects. Looks in Dataset attributes first, and then looks in DataArray. @@ -54,13 +66,16 @@ def get_attributes(xr_obj, strg): """ if strg in xr_obj.attrs: return xr_obj.attrs[strg] - elif str(type(xr_obj)) == "": - if strg in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable - return xr_obj[list(xr_obj.data_vars)[0]].attrs[strg] + elif strg in xr_obj.dims: return strg # special case for 'time' because DataArray and Dataset dims are not the same types + + elif isinstance(array, xr.Dataset): + if strg in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable + return xr_obj[list(xr_obj.data_vars)[0]].attrs[strg] + else: - print('Attribute "{0}" not found in attributes'.format(strg)) + warnings.warn('Attribute "{0}" not found in attributes'.format(strg)) return '' @@ -75,39 +90,67 @@ def set_plot_attrs(attr_dict, xr_obj, ax): ax: matplotlib axis Returns: matplotlib axis - Todo: include lat,lon coordinates in title, add warning if input not in list (e.g. y_label) + """ + # check + for key in attr_dict: + if key not in ['title','xlabel', 'ylabel', 'yunits']: + warnings.warn('Use_attrs element "{}" not supported'.format(key)) + if 'title' in attr_dict: - ax.set_title(get_attributes(xr_obj, attr_dict['title']), wrap=True) + if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: + ax.set_title(get_attributes(attr_dict['title'], xr_obj) + + ' (lat={:.2f}, lon={:.2f})'.format(float(xr_obj['lat']), + float(xr_obj['lon'])), + wrap=True) + else: + ax.set_title(get_attributes(attr_dict['title'], xr_obj), wrap=True) + if 'xlabel' in attr_dict: - ax.set_xlabel(get_attributes(xr_obj, attr_dict['xlabel'])) # rotation? + ax.set_xlabel(get_attributes(attr_dict['xlabel'], xr_obj)) + if 'ylabel' in attr_dict: if 'units' in attr_dict and len(attr_dict['units']) >= 1: # second condition avoids '[]' as label - ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel']) + ' [' + - get_attributes(xr_obj, attr_dict['yunits']) + ']') + ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj) + ' [' + + get_attributes(attr_dict['yunits'], xr_obj) + ']') else: - ax.set_ylabel(get_attributes(xr_obj, attr_dict['ylabel'])) + ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj)) return ax def sort_lines(array_dict): """ - Sorts and labels same-length arrays that plot as parallel lines in x,y space - according to the highest and lowest along the y-axis + Labels arrays as 'middle', 'upper' and 'lower' for ensemble plotting Args: array_dict: dict of arrays. Returns: dict """ - ref_values = {} + if len(array_dict) != 3: + raise ValueError('Ensembles must contain exactly three arrays') + sorted_lines = {} - for name, arr in array_dict.items(): - ref_values[name] = float(arr[int(len(arr)/2)]) # why the first int?? - sorted_series = pd.Series(ref_values).sort_values() - sorted_lines['lower'] = sorted_series.idxmin() - sorted_lines['upper'] = sorted_series.idxmax() - sorted_lines['middle'] = sorted_series.index[int(len(sorted_series)/2 - 0.5)] # -0.5 is + 0.5 - 1, to account for 0-indexing + for name in array_dict.keys(): + if re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", name): + suffix = re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", name).group() + + if suffix.isalpha(): + if suffix in ['max', 'Max']: + sorted_lines['upper'] = name + elif suffix in ['min', 'Min']: + sorted_lines['lower'] = name + elif suffix in ['mean', 'Mean']: + sorted_lines['middle'] = name + elif suffix.isdigit(): + if int(suffix) >= 51: + sorted_lines['upper'] = name + elif int(suffix) <= 49: + sorted_lines['lower'] = name + elif int(suffix) == 50: + sorted_lines['middle'] = name + else: + raise Exception('Arrays names must end in format "_mean" or "_p50" ') return sorted_lines From 4eb742f65a48f5ba0ccfd053e6700a7b0106a5c6 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Wed, 1 Feb 2023 18:05:23 -0500 Subject: [PATCH 14/25] lat,on out of title, label bug fixed, in-plot legend first (buggy) version --- spirograph/matplotlib/get_example_data.py | 2 +- spirograph/matplotlib/timeseries.py | 72 ++++++------ spirograph/matplotlib/timeseries_test.py | 9 +- spirograph/matplotlib/util_fcts.py | 136 ++++++++++++++++------ 4 files changed, 142 insertions(+), 77 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index 96987491..34e12f21 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -66,7 +66,7 @@ def output_ds(paths): target_lat = 45.5 target_lon = -73.6 - time_slice = slice(150,200) + time_slice = slice(160,260) dsets = {} diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 47b16989..8a1e279d 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -1,21 +1,15 @@ -import numpy as np import xarray as xr import matplotlib.pyplot as plt from matplotlib.patches import Patch -import warnings -import pandas as pd # To add -# language +# translation to fr # logo -# FIX label used twice -# show percentiles? +# FIX full legend when multiple ensembles # ylabel at end of line rather than legend -#special term in use_attrs to use dict key -#change fill_between() edgecolor to the background color -def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines'): +def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines', show_coords = True): """ Plots time series from 1D dataframe or dataset Args: @@ -25,7 +19,8 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li to a DataArray attribute (value, e.g. 'Description') sub_kw: matplotlib subplot kwargs line_kw: matplotlib or xarray line kwargs - legend: 'full' (lines and shading), 'lines' (lines only), 'none' (no legend) + legend: 'full' (lines and shading), 'lines' (lines only), + 'in_plot' (self-expl.), 'none' (no legend) Returns: matplotlib axis """ @@ -34,8 +29,8 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li non_dict_data = False if type(data) != dict: - data = {'data_1': data} - line_kw = {'data_1': empty_dict(line_kw)} + data = {'no_label': data} + line_kw = {'no_label': empty_dict(line_kw)} non_dict_data = True # basic checks @@ -78,10 +73,14 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li array_categ = {name: get_array_categ(array) for name, array in data.items()} # get data and plot - lines_dict = {} + lines_dict = {} # created to facilitate accessing line properties later for name, arr in data.items(): + # add name in line kwargs if not there, to avoid error due to double 'label' args in plot() + if 'label' not in kwargs['line_kw'][name]: + kwargs['line_kw'][name]['label'] = name + # ensembles if array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS_DA']: @@ -97,51 +96,52 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li # create a dictionary labeling the middle, upper and lower line sorted_lines = sort_lines(array_data) - # plot the ensemble + # plot line lines_dict[name] = ax.plot(array_data[sorted_lines['middle']]['time'], - array_data[sorted_lines['middle']].values, **kwargs['line_kw'][name], label=name) + array_data[sorted_lines['middle']].values, + **kwargs['line_kw'][name]) + + # plot shading + if array_categ[name] in ['PCT_VAR_ENS', 'PCT_DIM_ENS_DA']: + fill_between_label = "{}th-{}th percentiles".format(get_suffix(sorted_lines['lower']), + get_suffix(sorted_lines['upper'])) + if array_categ[name] in ['STATS_VAR_ENS']: + fill_between_label = "min-max range" + if legend != 'full': + fill_between_label = None ax.fill_between(array_data[sorted_lines['lower']]['time'], array_data[sorted_lines['lower']].values, array_data[sorted_lines['upper']].values, color=lines_dict[name][0].get_color(), - edgecolor='white', alpha=0.2) - - if legend == 'full': - patch = Patch(facecolor=lines_dict[name][0].get_color(), - edgecolor='white', alpha=0.2, - label="{} - {}".format(sorted_lines['lower'], - sorted_lines['upper'])) - + linewidth = 0.0, alpha=0.2, label=fill_between_label) # non-ensemble Datasets elif array_categ[name] in ['NON_ENS_DS']: for k, sub_arr in arr.data_vars.items(): - sub_name = name + "_" + sub_arr.name # creates plot label - lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, - **kwargs['line_kw'][name], label=sub_name - ) + sub_name = kwargs['line_kw'][name]['label'] + "_" + sub_arr.name + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values,**kwargs['line_kw'][name]) # non-ensemble DataArrays else: - lines_dict[name] = ax.plot(arr['time'], arr.values, - **kwargs['line_kw'][name], label=name - ) + lines_dict[name] = ax.plot(arr['time'], arr.values,**kwargs['line_kw'][name]) # add/modify plot elements according to the first entry. set_plot_attrs(plot_attrs, list(data.values())[0], ax) - # other plot elements + # other plot elements + + ax.margins(x=0, y=0.05) + + if show_coords: + plot_coords(ax, list(data.values())[0]) if non_dict_data is False and legend is not None: - if legend == 'full': - handles = [v[0] for v in list(lines_dict.values())] # line objects are tuples(?) - handles.append(patch) - ax.legend(handles=handles) + if legend == 'in_plot': + in_plot_legend(ax) else: ax.legend() - ax.margins(x=0, y=0.05) return ax diff --git a/spirograph/matplotlib/timeseries_test.py b/spirograph/matplotlib/timeseries_test.py index c22d289c..c89f60b0 100644 --- a/spirograph/matplotlib/timeseries_test.py +++ b/spirograph/matplotlib/timeseries_test.py @@ -19,13 +19,14 @@ line_ts({'Random data': da_pct_rand}) ## simple Dataset ensemble (variables) -line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full') +line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full', show_coords = True) + line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, line_kw={'rcp45_2015_1': {'color': 'purple'}}) ## simple Dataset ensemble (dims), title override my_ax = line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}, - line_kw={'rcp45_2015_1': {'color': '#daa520'}}) + line_kw={'rcp45_2015_1': {'color': '#daa520'}}, legend = 'full') my_ax.set_title('The percentiles are in dimensions') ## one DataArray, one pct Dataset, one stats Dataset @@ -35,9 +36,9 @@ line_kw={'DataArray': {'color': '#000080'}, 'Dataset_vars': {'color': '#ffa500'}, 'Dataset_dims': {'color': '#468499'} - }) + }, legend = 'full') # test with non-ensemble DS +# test with pct_dim_ens_ds -ll = line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full') diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index b3fc533e..fb294f62 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -54,7 +54,7 @@ def get_array_categ(array): return cat -def get_attributes(strg, xr_obj): +def get_attributes(string, xr_obj): """ Fetches attributes or dims corresponding to keys from Xarray objects. Looks in Dataset attributes first, and then looks in DataArray. @@ -64,18 +64,18 @@ def get_attributes(strg, xr_obj): Returns: Xarray attribute value as string """ - if strg in xr_obj.attrs: - return xr_obj.attrs[strg] + if string in xr_obj.attrs: + return xr_obj.attrs[string] - elif strg in xr_obj.dims: - return strg # special case for 'time' because DataArray and Dataset dims are not the same types + elif string in xr_obj.dims: + return string # special case for 'time' because DataArray and Dataset dims are not the same types - elif isinstance(array, xr.Dataset): - if strg in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable - return xr_obj[list(xr_obj.data_vars)[0]].attrs[strg] + elif isinstance(xr_obj, xr.Dataset): + if string in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable + return xr_obj[list(xr_obj.data_vars)[0]].attrs[string] else: - warnings.warn('Attribute "{0}" not found in attributes'.format(strg)) + warnings.warn('Attribute "{0}" not found in attributes'.format(string)) return '' @@ -98,25 +98,27 @@ def set_plot_attrs(attr_dict, xr_obj, ax): warnings.warn('Use_attrs element "{}" not supported'.format(key)) if 'title' in attr_dict: - if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: - ax.set_title(get_attributes(attr_dict['title'], xr_obj) + - ' (lat={:.2f}, lon={:.2f})'.format(float(xr_obj['lat']), - float(xr_obj['lon'])), - wrap=True) - else: - ax.set_title(get_attributes(attr_dict['title'], xr_obj), wrap=True) + ax.set_title(get_attributes(attr_dict['title'], xr_obj), wrap=True) if 'xlabel' in attr_dict: ax.set_xlabel(get_attributes(attr_dict['xlabel'], xr_obj)) if 'ylabel' in attr_dict: - if 'units' in attr_dict and len(attr_dict['units']) >= 1: # second condition avoids '[]' as label + if 'yunits' in attr_dict and len(attr_dict['yunits']) >= 1: # second condition avoids '[]' as label ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj) + ' [' + get_attributes(attr_dict['yunits'], xr_obj) + ']') else: ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj)) return ax +def get_suffix(string): + """ get suffix of typical Xclim variable names""" + if re.search("[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string): + suffix = re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", string).group() + return suffix + else: + raise Exception('No suffix found in {}'.format(string)) + def sort_lines(array_dict): """ @@ -132,31 +134,93 @@ def sort_lines(array_dict): sorted_lines = {} for name in array_dict.keys(): - if re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", name): - suffix = re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", name).group() - - if suffix.isalpha(): - if suffix in ['max', 'Max']: - sorted_lines['upper'] = name - elif suffix in ['min', 'Min']: - sorted_lines['lower'] = name - elif suffix in ['mean', 'Mean']: - sorted_lines['middle'] = name - elif suffix.isdigit(): - if int(suffix) >= 51: - sorted_lines['upper'] = name - elif int(suffix) <= 49: - sorted_lines['lower'] = name - elif int(suffix) == 50: - sorted_lines['middle'] = name - else: - raise Exception('Arrays names must end in format "_mean" or "_p50" ') + suffix = get_suffix(name) + + if suffix.isalpha(): + if suffix in ['max', 'Max']: + sorted_lines['upper'] = name + elif suffix in ['min', 'Min']: + sorted_lines['lower'] = name + elif suffix in ['mean', 'Mean']: + sorted_lines['middle'] = name + elif suffix.isdigit(): + if int(suffix) >= 51: + sorted_lines['upper'] = name + elif int(suffix) <= 49: + sorted_lines['lower'] = name + elif int(suffix) == 50: + sorted_lines['middle'] = name + else: + raise Exception('Arrays names must end in format "_mean" or "_p50" ') return sorted_lines +def plot_coords(ax, xr_obj): + if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: + text = 'lat={:.2f}, lon={:.2f}'.format(float(xr_obj['lat']), + float(xr_obj['lon'])) + ax.text(0.99, 0.01, text, transform=ax.transAxes, ha = 'right', va = 'bottom') + else: + raise Exception('show_coords set to True, but no coordonates found in {}.coords'.format(xr_obj)) + + return ax + + + +def in_plot_legend(ax, xlim_factor=0.08, label_gap=0.03, out = False): + """ + Draws line labels at the end of each line + Args: + xlim_factor: float + percentage of the x-axis length to add at the far right of the plot + label_gap: float + percentage of the x-axis length to add as a gap between line and label + + Returns: + matplotlib axis + """ + #create extra space + + + init_xlim = ax.get_xlim() + ax.set_xlim(xmin=init_xlim[0], + xmax=init_xlim[1] + (init_xlim[1] * xlim_factor)) + #get legend and plot + + handles, labels = ax.get_legend_handles_labels() + for handle, label in zip(handles, labels): + last_pt = (handle.get_xdata()[-1], handle.get_ydata()[-1]) + last_pt_dsp = ax.transData.transform(last_pt) + last_pt_ax = ax.transAxes.inverted().transform(last_pt_dsp) + last_x = last_pt_ax[0] + last_y = last_pt_ax[1] + color = handle.get_color() + ls = handle.get_linestyle() + + if out is False: + ax.text(last_x + (label_gap * last_x), last_y, label, + ha='left', va='center', color=color, transform=ax.transAxes) + + + if out is True: + ax.text(1.05, last_y, label, ha='left', va='center', color=color, transform=ax.transAxes) + ax.plot([1, 1.2], [last_y, last_y], ls=ls, color=color, transform=ax.transAxes) + #ax.axhline(y=last_y, xmin=last_x, xmax = 1,ls="-", color=color) + + return ax +fig, ax = plt.subplots() +ax.plot([1, 2, 3], [5, 8, 9], label='1st LABEL') +ax.plot([1, 2, 3], [8, 2, 4], label='2nd LABEL') +in_plot_legend(ax, out = False) +fig, [ax1, ax2] = plt.subplots(1,2) +for ax in [ax1,ax2]: + ax.plot([1,2,3],[5,8,9], label = 'YASS') + ax.plot([1,2,3],[8,2,4], label = 'YES') +in_plot_legend(ax1) +in_plot_legend(ax2, out=True) From 0d556ed5c3160b1a37348c50655dd0d2ac1113be Mon Sep 17 00:00:00 2001 From: abeaupre Date: Mon, 6 Feb 2023 11:14:38 -0500 Subject: [PATCH 15/25] split_legend func, docstrings updated, non-ensemble DS legend fixed --- spirograph/matplotlib/get_example_data.py | 32 +++- spirograph/matplotlib/timeseries.py | 79 ++++++---- spirograph/matplotlib/timeseries_test.py | 28 ++-- spirograph/matplotlib/util_fcts.py | 182 +++++++++++++--------- 4 files changed, 194 insertions(+), 127 deletions(-) diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py index 34e12f21..3eb70048 100644 --- a/spirograph/matplotlib/get_example_data.py +++ b/spirograph/matplotlib/get_example_data.py @@ -21,8 +21,8 @@ ens2015_rcp45 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp45_*_2015.nc') tasmax_rcp45_2015_2 = ensembles.create_ensemble(ens2015_rcp45[0:3]) -tasmax_rcp45_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp45_2015_2) -tasmax_rcp45_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp45_2015_2, values=[15, 50, 85], split=False) +tasmax_rcp45_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp45_2015_2).sel(lat =slice(65,40), lon = slice(-90,-55)) +tasmax_rcp45_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp45_2015_2, values=[15, 50, 85], split=False).sel(lat =slice(65,40), lon = slice(-90,-55)) tasmax_rcp45_2015_2_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_2_stats.nc') tasmax_rcp45_2015_2_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_2_perc.nc') @@ -31,8 +31,8 @@ ens2015_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2015.nc') tasmax_rcp85_2015_1 = ensembles.create_ensemble(ens2015_rcp85[3:6]) -tasmax_rcp85_2015_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_1) -tasmax_rcp85_2015_1_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_1, values=[15, 50, 85], split=False) +tasmax_rcp85_2015_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_1).sel(lat =slice(65,40), lon = slice(-90,-55)) +tasmax_rcp85_2015_1_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_1, values=[15, 50, 85], split=False).sel(lat =slice(65,40), lon = slice(-90,-55)) tasmax_rcp85_2015_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_1_stats.nc') tasmax_rcp85_2015_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_1_perc.nc') @@ -42,8 +42,8 @@ ens2015_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2015.nc') tasmax_rcp85_2015_2 = ensembles.create_ensemble(ens2015_rcp85[0:3]) -tasmax_rcp85_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_2) -tasmax_rcp85_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_2, values=[15, 50, 85], split=False) +tasmax_rcp85_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_2).sel(lat =slice(65,40), lon = slice(-90,-55)) +tasmax_rcp85_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_2, values=[15, 50, 85], split=False).sel(lat =slice(65,40), lon = slice(-90,-55)) tasmax_rcp85_2015_2_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_2_stats.nc') tasmax_rcp85_2015_2_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_2_perc.nc') @@ -98,6 +98,11 @@ def output_ds(paths): #datasets['tasmax_rcp45_2015_1_stats'] +# make percentile dataset + +ds_perc = xr.Dataset({'rcp45': datasets['tasmax_rcp45_2015_1_perc'], + 'rcp85': datasets['tasmax_rcp85_2015_1_perc']}) + # Other datasets ## ensemble percentiles (pct in variables) url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' @@ -108,10 +113,23 @@ def output_ds(paths): ## randomly-generated ensemble percentiles (pct in dims). No attributes data = np.random.rand(4,3)*25 + 300 -time = pd.date_range(start ='1960-01-01', end = '2020-01-01', periods = 4) +time = pd.date_range(start='1960-01-01', end='2020-01-01', periods=4) pct = [15,50,95] da_pct_rand = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) attr_list = ['long_name','time','standard_name','units'] for a in attr_list: da_pct_rand.attrs[a] = 'default' + +## randomly-generated non-ensemble dataset + +time = pd.date_range(start ='1960-01-01', end = '2020-01-01', periods = 10) +dat_1 = np.random.rand(10) * 20 +dat_2 = np.random.rand(10) * 20 +dat_3 = np.random.rand(10) * 20 + +rand_ds = xr.Dataset(data_vars={'data1': ('time', dat_1), + 'data2': ('time', dat_2), + 'data3': ('time', dat_3)}, + coords={'time': time}, + attrs={'description': 'Randomly generated time-series'}) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 8a1e279d..45aa7bb4 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -1,27 +1,32 @@ - import xarray as xr import matplotlib.pyplot as plt -from matplotlib.patches import Patch # To add # translation to fr # logo -# FIX full legend when multiple ensembles -# ylabel at end of line rather than legend -def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines', show_coords = True): +def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines', show_coords = True): """ - Plots time series from 1D dataframe or dataset - Args: - data: dictionary of labeled Xarray DataArrays or Datasets - ax: user-specified matplotlib axis - use_attrs: dict linking a plot element (key, e.g. 'title') - to a DataArray attribute (value, e.g. 'Description') - sub_kw: matplotlib subplot kwargs - line_kw: matplotlib or xarray line kwargs - legend: 'full' (lines and shading), 'lines' (lines only), - 'in_plot' (self-expl.), 'none' (no legend) - Returns: + Plots time series from 1D dataframes or datasets + Parameters + __________ + data: dict or Dataset/DataArray + dictionary of labeled Xarray DataArrays or Datasets + ax: matplotlib axis + user-specified matplotlib axis + use_attrs: dict + dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description') + sub_kw: dict + matplotlib subplots kwargs in the format {'param': value} + line_kw: dict + matplotlib or xarray line kwargs in the format {'param': value} + legend: str + 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), + 'edge' (out of plot), 'none' (no legend) + show_coords: bool + show latitude, longitude coordinates at the bottom right of the figure + Returns + _______ matplotlib axis """ @@ -29,8 +34,8 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li non_dict_data = False if type(data) != dict: - data = {'no_label': data} - line_kw = {'no_label': empty_dict(line_kw)} + data = {'_no_label': data} # mpl excludes labels starting with "_" from legend + line_kw = {'_no_label': empty_dict(line_kw)} non_dict_data = True # basic checks @@ -42,7 +47,6 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li ## 'time' dimension and calendar format data = check_timeindex(data) - # set default kwargs plot_attrs = {'title': 'long_name', 'xlabel': 'time', @@ -117,35 +121,46 @@ def line_ts(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='li linewidth = 0.0, alpha=0.2, label=fill_between_label) # non-ensemble Datasets - elif array_categ[name] in ['NON_ENS_DS']: + elif array_categ[name] in ['DS']: for k, sub_arr in arr.data_vars.items(): - sub_name = kwargs['line_kw'][name]['label'] + "_" + sub_arr.name - lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values,**kwargs['line_kw'][name]) + if non_dict_data is True: + sub_name = sub_arr.name + else: + sub_name = kwargs['line_kw'][name]['label'] + "_" + sub_arr.name + + #put sub_name in line_kwargs to label correctly on plot, store the + # original, and put it back after + store_label = kwargs['line_kw'][name]['label'] + kwargs['line_kw'][name]['label'] = sub_name + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **kwargs['line_kw'][name]) + kwargs['line_kw'][name]['label'] = store_label + # non-ensemble DataArrays + elif array_categ[name] in ['DA']: + lines_dict[name] = ax.plot(arr['time'], arr.values, **kwargs['line_kw'][name]) + else: - lines_dict[name] = ax.plot(arr['time'], arr.values,**kwargs['line_kw'][name]) + raise Exception('Data structure not supported') # add/modify plot elements according to the first entry. set_plot_attrs(plot_attrs, list(data.values())[0], ax) - # other plot elements + # other plot elements (check overlap with Stylesheet!) ax.margins(x=0, y=0.05) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) if show_coords: plot_coords(ax, list(data.values())[0]) - if non_dict_data is False and legend is not None: + if legend is not None: # non_dict_data is False and if legend == 'in_plot': - in_plot_legend(ax) + split_legend(ax, out=False) + elif legend == 'edge': + split_legend(ax, out=True) else: ax.legend() - return ax - - - - - diff --git a/spirograph/matplotlib/timeseries_test.py b/spirograph/matplotlib/timeseries_test.py index c89f60b0..1583de29 100644 --- a/spirograph/matplotlib/timeseries_test.py +++ b/spirograph/matplotlib/timeseries_test.py @@ -2,6 +2,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt mpl.use("Qt5Agg") +#mpl.style.use('dark_background') # mpl.style.available @@ -10,35 +11,38 @@ ## 1 . Basic plot functionality ## simple DataArray, unlabeled -line_ts(da_pct_1, line_kw={'color': 'red'}) +timeseries(da_pct_1, line_kw={'color': 'red'}) ## simple DataArray, labeled -line_ts({'My data': da_pct_1}, line_kw={'My data': {'color': 'red'}}) +timeseries({'My data': da_pct_1}, line_kw={'My data': {'color': 'red'}}) ## idem, with no attributes -line_ts({'Random data': da_pct_rand}) +timeseries({'Random data': da_pct_rand}) ## simple Dataset ensemble (variables) -line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full', show_coords = True) +timeseries({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full', show_coords = True) -line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, +timeseries({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, line_kw={'rcp45_2015_1': {'color': 'purple'}}) ## simple Dataset ensemble (dims), title override -my_ax = line_ts({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}, +my_ax = timeseries({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}, line_kw={'rcp45_2015_1': {'color': '#daa520'}}, legend = 'full') my_ax.set_title('The percentiles are in dimensions') ## one DataArray, one pct Dataset, one stats Dataset -line_ts({'DataArray': datasets['tasmax_rcp45_2015_1_stats']['tasmax_mean'], - 'Dataset_vars': datasets['tasmax_rcp45_2015_2_stats'], - 'Dataset_dims': datasets['tasmax_rcp85_2015_1_perc']}, - line_kw={'DataArray': {'color': '#000080'}, +timeseries({'DataArray': datasets['tasmax_rcp45_2015_1_stats']['tasmax_mean'], + 'Dataset_vars': datasets['tasmax_rcp45_2015_2_stats'], + 'Dataset_dims': datasets['tasmax_rcp85_2015_1_perc']}, + line_kw={'DataArray': {'color': '#8a2be2'}, 'Dataset_vars': {'color': '#ffa500'}, 'Dataset_dims': {'color': '#468499'} - }, legend = 'full') + }, legend='edge') # test with non-ensemble DS -# test with pct_dim_ens_ds +timeseries(rand_ds) +#test different length arrays +timeseries({'random': rand_ds,'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}) +# diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index fb294f62..c4a138d3 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -1,17 +1,30 @@ - import pandas as pd import re import warnings import xarray as xr +import matplotlib as mpl + def empty_dict(param): + """ returns empty dict if input is None""" if param is None: param = {} return param + def check_timeindex(xr_dict): """ checks if the time index of Xarray objects in a dict is CFtime - and converts to pd.DatetimeIndex if true""" + and converts to pd.DatetimeIndex if true + + Parameters + _________ + xr_dict: dict + dictionary containing Xarray DataArrays or Datasets + Returns + _______ + dict + """ + for name, xr_obj in xr_dict.items(): if 'time' in xr_obj.dims: if isinstance(xr_obj.get_index('time'), xr.CFTimeIndex): @@ -21,27 +34,32 @@ def check_timeindex(xr_dict): raise ValueError('"time" dimension not found in {}'.format(xr_obj)) return xr_dict + def get_array_categ(array): - """Returns an array category + """Returns an array category, which determines how to plot + + Parameters + __________ + array: Dataset or DataArray + + Returns + _________ + str PCT_VAR_ENS: ensemble of percentiles stored as variables PCT_DIM_ENS_DA: ensemble of percentiles stored as dimension coordinates, DataArray STATS_VAR_ENS: ensemble of statistics (min, mean, max) stored as variables - NON_ENS_DS: dataset of individual lines, not an ensemble + DS: any Dataset that is not recognized as an ensemble DA: DataArray - Args: - data_dict: Xarray Dataset or DataArray - Returns - str - """ + """ if isinstance(array, xr.Dataset): if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >=2: cat = "PCT_VAR_ENS" elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: cat = "STATS_VAR_ENS" elif pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: - cat = "PCT_DIM_ENS_DS" ## no support for now + cat = "PCT_DIM_ENS_DS" # placeholder, no support for now else: - cat = "NON_ENS_DS" + cat = "DS" elif isinstance(array, xr.DataArray): if pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: @@ -57,12 +75,19 @@ def get_array_categ(array): def get_attributes(string, xr_obj): """ Fetches attributes or dims corresponding to keys from Xarray objects. Looks in - Dataset attributes first, and then looks in DataArray. - Args: - xr_obj: Xarray DataArray or Dataset - str: string corresponding to an attribute key - Returns: - Xarray attribute value as string + Dataset attributes first, then looks in DataArray. + + Parameters + _________ + string: str + string corresponding to an attribute name + xr_obj: DataArray or Dataset + the Xarray object containing the attributes + + Returns + _______ + str + Xarray attribute value as string or empty string if not found """ if string in xr_obj.attrs: return xr_obj.attrs[string] @@ -74,22 +99,27 @@ def get_attributes(string, xr_obj): if string in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable return xr_obj[list(xr_obj.data_vars)[0]].attrs[string] - else: - warnings.warn('Attribute "{0}" not found in attributes'.format(string)) - return '' - + else: + warnings.warn('Attribute "{0}" not found in attributes'.format(string)) + return '' ## would it be better to return None? if so, need to fix ylabel in set_plot_attrs() def set_plot_attrs(attr_dict, xr_obj, ax): """ Sets plot elements according to Dataset or DataArray attributes. Uses get_attributes() to check for and return the string. - Args: - use_attrs (dict): dict containing specified attribute keys - xr_obj: Xarray DataArray. - ax: matplotlib axis - Returns: - matplotlib axis + + Parameters + __________ + use_attrs: dict + dictionary containing specified attribute keys + xr_obj: Dataset or DataArray + The Xarray object containing the attributes + ax: matplotlib axis + the matplotlib axis + Returns + ______ + matplotlib axis """ # check @@ -104,13 +134,14 @@ def set_plot_attrs(attr_dict, xr_obj, ax): ax.set_xlabel(get_attributes(attr_dict['xlabel'], xr_obj)) if 'ylabel' in attr_dict: - if 'yunits' in attr_dict and len(attr_dict['yunits']) >= 1: # second condition avoids '[]' as label - ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj) + ' [' + - get_attributes(attr_dict['yunits'], xr_obj) + ']') + if 'yunits' in attr_dict and len(get_attributes(attr_dict['yunits'], xr_obj)) >= 1: # second condition avoids '[]' as label + ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj) + ' (' + + get_attributes(attr_dict['yunits'], xr_obj) + ')') else: ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj)) return ax + def get_suffix(string): """ get suffix of typical Xclim variable names""" if re.search("[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string): @@ -123,10 +154,15 @@ def get_suffix(string): def sort_lines(array_dict): """ Labels arrays as 'middle', 'upper' and 'lower' for ensemble plotting - Args: - array_dict: dict of arrays. - Returns: - dict + + Parameters + _______ + array_dict: dict of {'name': array}. + + Returns + _______ + dict + dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'} """ if len(array_dict) != 3: raise ValueError('Ensembles must contain exactly three arrays') @@ -156,71 +192,65 @@ def sort_lines(array_dict): def plot_coords(ax, xr_obj): + """ place lat, lon coordinates on bottom right of plot area""" if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: text = 'lat={:.2f}, lon={:.2f}'.format(float(xr_obj['lat']), float(xr_obj['lon'])) ax.text(0.99, 0.01, text, transform=ax.transAxes, ha = 'right', va = 'bottom') else: - raise Exception('show_coords set to True, but no coordonates found in {}.coords'.format(xr_obj)) + warnings.warn('show_coords set to True, but no coordinates found in {}.coords'.format(xr_obj)) return ax - -def in_plot_legend(ax, xlim_factor=0.08, label_gap=0.03, out = False): +def split_legend(ax, out = True, axis_factor=0.15, label_gap=0.02): """ - Draws line labels at the end of each line - Args: - xlim_factor: float - percentage of the x-axis length to add at the far right of the plot - label_gap: float - percentage of the x-axis length to add as a gap between line and label - - Returns: + Draws line labels at the end of each line, or outside the plot + + Parameters + _______ + ax: matplotlib axis + the axis containing the legend + out: bool (default True) + if True, print the labels outside of plot area. if False, prolongs plot area to fit labels + axis_factor: float + if out is True, percentage of the x-axis length to add at the far right of the plot + label_gap: float + if out is True,percentage of the x-axis length to add as a gap between line and label + + Returns + ______ matplotlib axis """ + #create extra space + init_xbound = ax.get_xbound() + ax_bump = (init_xbound[1] - init_xbound[0]) * axis_factor + label_bump = (init_xbound[1] - init_xbound[0]) * label_gap - init_xlim = ax.get_xlim() - ax.set_xlim(xmin=init_xlim[0], - xmax=init_xlim[1] + (init_xlim[1] * xlim_factor)) + if out is False: + ax.set_xbound(lower=init_xbound[0], upper=init_xbound[1] + ax_bump) #get legend and plot handles, labels = ax.get_legend_handles_labels() for handle, label in zip(handles, labels): - last_pt = (handle.get_xdata()[-1], handle.get_ydata()[-1]) - last_pt_dsp = ax.transData.transform(last_pt) - last_pt_ax = ax.transAxes.inverted().transform(last_pt_dsp) - last_x = last_pt_ax[0] - last_y = last_pt_ax[1] + + last_x = handle.get_xdata()[-1] + last_y = handle.get_ydata()[-1] + + if isinstance(last_x, np.datetime64): + last_x = mpl.dates.date2num(last_x) + color = handle.get_color() ls = handle.get_linestyle() if out is False: - ax.text(last_x + (label_gap * last_x), last_y, label, - ha='left', va='center', color=color, transform=ax.transAxes) - - - if out is True: - ax.text(1.05, last_y, label, ha='left', va='center', color=color, transform=ax.transAxes) - ax.plot([1, 1.2], [last_y, last_y], ls=ls, color=color, transform=ax.transAxes) - #ax.axhline(y=last_y, xmin=last_x, xmax = 1,ls="-", color=color) + ax.text(last_x + label_bump, last_y, label, + ha='left', va='center', color=color) + else: + trans = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData) + ax.text(1.01, last_y, label, ha='left', va='center', color=color, transform=trans) return ax - -fig, ax = plt.subplots() -ax.plot([1, 2, 3], [5, 8, 9], label='1st LABEL') -ax.plot([1, 2, 3], [8, 2, 4], label='2nd LABEL') -in_plot_legend(ax, out = False) - - - -fig, [ax1, ax2] = plt.subplots(1,2) -for ax in [ax1,ax2]: - ax.plot([1,2,3],[5,8,9], label = 'YASS') - ax.plot([1,2,3],[8,2,4], label = 'YES') -in_plot_legend(ax1) -in_plot_legend(ax2, out=True) - From d2688b426c30d0af605634dbc665c7d8bbfa09d5 Mon Sep 17 00:00:00 2001 From: Beauprel <121904497+Beauprel@users.noreply.github.com> Date: Mon, 6 Feb 2023 12:17:02 -0500 Subject: [PATCH 16/25] Delete get_example_data.py --- spirograph/matplotlib/get_example_data.py | 135 ---------------------- 1 file changed, 135 deletions(-) delete mode 100644 spirograph/matplotlib/get_example_data.py diff --git a/spirograph/matplotlib/get_example_data.py b/spirograph/matplotlib/get_example_data.py deleted file mode 100644 index 3eb70048..00000000 --- a/spirograph/matplotlib/get_example_data.py +++ /dev/null @@ -1,135 +0,0 @@ -import xarray as xr -import numpy as np -import pandas as pd -import glob -from xclim import ensembles -import re - -# create NetCDFs - -## rcp4.5, 2015, 3 models -ens2015_rcp45 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp45_*_2015.nc') -tasmax_rcp45_2015_1 = ensembles.create_ensemble(ens2015_rcp45[3:6]) - -tasmax_rcp45_2015_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp45_2015_1) -tasmax_rcp45_2015_1_perc = ensembles.ensemble_percentiles(tasmax_rcp45_2015_1, values=[15, 50, 85], split=False) - -tasmax_rcp45_2015_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_1_stats.nc') -tasmax_rcp45_2015_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_1_perc.nc') - -## rcp4.5, 2015, 3 other models -ens2015_rcp45 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp45_*_2015.nc') -tasmax_rcp45_2015_2 = ensembles.create_ensemble(ens2015_rcp45[0:3]) - -tasmax_rcp45_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp45_2015_2).sel(lat =slice(65,40), lon = slice(-90,-55)) -tasmax_rcp45_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp45_2015_2, values=[15, 50, 85], split=False).sel(lat =slice(65,40), lon = slice(-90,-55)) - -tasmax_rcp45_2015_2_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_2_stats.nc') -tasmax_rcp45_2015_2_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp45_2015_2_perc.nc') - -## rcp8.5, 2015, 3 other models -ens2015_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2015.nc') -tasmax_rcp85_2015_1 = ensembles.create_ensemble(ens2015_rcp85[3:6]) - -tasmax_rcp85_2015_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_1).sel(lat =slice(65,40), lon = slice(-90,-55)) -tasmax_rcp85_2015_1_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_1, values=[15, 50, 85], split=False).sel(lat =slice(65,40), lon = slice(-90,-55)) - -tasmax_rcp85_2015_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_1_stats.nc') -tasmax_rcp85_2015_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_1_perc.nc') - - -## rcp8.5, 2015, 3 other models -ens2015_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2015.nc') -tasmax_rcp85_2015_2 = ensembles.create_ensemble(ens2015_rcp85[0:3]) - -tasmax_rcp85_2015_2_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2015_2).sel(lat =slice(65,40), lon = slice(-90,-55)) -tasmax_rcp85_2015_2_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2015_2, values=[15, 50, 85], split=False).sel(lat =slice(65,40), lon = slice(-90,-55)) - -tasmax_rcp85_2015_2_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_2_stats.nc') -tasmax_rcp85_2015_2_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2015_2_perc.nc') - -## rcp4.5, 2012, 3 models -ens2012_rcp85 = glob.glob('/scen3/scenario/netcdf/ouranos/cb-oura-1.0/tasmax_day_*_rcp85_*_2012.nc') -tasmax_rcp85_2012_1 = ensembles.create_ensemble(ens2012_rcp85[5:8]) - -tasmax_rcp85_2012_1_stats = ensembles.ensemble_mean_std_max_min(tasmax_rcp85_2012_1) -tasmax_rcp85_2012_1_perc = ensembles.ensemble_percentiles(tasmax_rcp85_2012_1, values=[15, 50, 85], split=False) - -tasmax_rcp85_2012_1_stats.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2012_1_stats.nc') -tasmax_rcp85_2012_1_perc.to_netcdf(path='/exec/abeaupre/Projects/spirograph/test_data/tasmax_rcp85_2012_1_perc.nc') - - -# import and process - - -def output_ds(paths): - - target_lat = 45.5 - target_lon = -73.6 - time_slice = slice(160,260) - - dsets = {} - - - for path in paths: - if re.search("_stats", path): - open_ds = xr.open_dataset(path, decode_timedelta=False) - var_ds = open_ds[['tasmax_mean', 'tasmax_min', 'tasmax_max']] - elif re.search("_perc", path): - open_ds = xr.open_dataset(path, decode_timedelta=False) - var_ds = open_ds.drop_dims('ts')['tasmax'] - else: - print(path, ' not _stats or _perc') - continue - - loc_ds = var_ds.sel(lat=target_lat, lon=target_lon, method='nearest') - #.convert_calendar('standard') - if time_slice: - loc_ds = loc_ds.isel(time=time_slice) - dsets[path.split(sep='/')[-1].split(sep='.')[0]] = loc_ds - - return dsets - - -paths = glob.glob('/exec/abeaupre/Projects/spirograph/test_data/tasmax*.nc') - -datasets = output_ds(paths) - - -#datasets['tasmax_rcp45_2015_1_stats'] - -# make percentile dataset - -ds_perc = xr.Dataset({'rcp45': datasets['tasmax_rcp45_2015_1_perc'], - 'rcp85': datasets['tasmax_rcp85_2015_1_perc']}) - -# Other datasets -## ensemble percentiles (pct in variables) -url_1 = 'https://pavics.ouranos.ca//twitcher/ows/proxy/thredds/dodsC/birdhouse/disk2/cccs_portal/indices/Final/BCCAQv2_CMIP6/tx_max/YS/ssp585/ensemble_percentiles/tx_max_ann_BCCAQ2v2+ANUSPLIN300_historical+ssp585_1950-2100_30ymean_percentiles.nc' -ds_pct_open = xr.open_dataset(url_1, decode_timedelta=False) - -ds_pct_1 = ds_pct_open.isel(lon=500, lat=250)[['tx_max_p50', 'tx_max_p10', 'tx_max_p90']] -da_pct_1 = ds_pct_1['tx_max_p50'] - -## randomly-generated ensemble percentiles (pct in dims). No attributes -data = np.random.rand(4,3)*25 + 300 -time = pd.date_range(start='1960-01-01', end='2020-01-01', periods=4) -pct = [15,50,95] - -da_pct_rand = xr.DataArray(data, coords = [time, pct], dims = ['time', 'percentiles']) -attr_list = ['long_name','time','standard_name','units'] -for a in attr_list: - da_pct_rand.attrs[a] = 'default' - -## randomly-generated non-ensemble dataset - -time = pd.date_range(start ='1960-01-01', end = '2020-01-01', periods = 10) -dat_1 = np.random.rand(10) * 20 -dat_2 = np.random.rand(10) * 20 -dat_3 = np.random.rand(10) * 20 - -rand_ds = xr.Dataset(data_vars={'data1': ('time', dat_1), - 'data2': ('time', dat_2), - 'data3': ('time', dat_3)}, - coords={'time': time}, - attrs={'description': 'Randomly generated time-series'}) From 18328e92fb5d729d74752417dcff9abe43c899a3 Mon Sep 17 00:00:00 2001 From: Beauprel <121904497+Beauprel@users.noreply.github.com> Date: Mon, 6 Feb 2023 12:17:21 -0500 Subject: [PATCH 17/25] Delete timeseries_test.py --- spirograph/matplotlib/timeseries_test.py | 48 ------------------------ 1 file changed, 48 deletions(-) delete mode 100644 spirograph/matplotlib/timeseries_test.py diff --git a/spirograph/matplotlib/timeseries_test.py b/spirograph/matplotlib/timeseries_test.py deleted file mode 100644 index 1583de29..00000000 --- a/spirograph/matplotlib/timeseries_test.py +++ /dev/null @@ -1,48 +0,0 @@ - -import matplotlib as mpl -import matplotlib.pyplot as plt -mpl.use("Qt5Agg") -#mpl.style.use('dark_background') # mpl.style.available - - - -# test - -## 1 . Basic plot functionality - -## simple DataArray, unlabeled -timeseries(da_pct_1, line_kw={'color': 'red'}) - -## simple DataArray, labeled -timeseries({'My data': da_pct_1}, line_kw={'My data': {'color': 'red'}}) - -## idem, with no attributes -timeseries({'Random data': da_pct_rand}) - -## simple Dataset ensemble (variables) -timeseries({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, legend = 'full', show_coords = True) - -timeseries({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_stats']}, - line_kw={'rcp45_2015_1': {'color': 'purple'}}) - -## simple Dataset ensemble (dims), title override -my_ax = timeseries({'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}, - line_kw={'rcp45_2015_1': {'color': '#daa520'}}, legend = 'full') -my_ax.set_title('The percentiles are in dimensions') - -## one DataArray, one pct Dataset, one stats Dataset -timeseries({'DataArray': datasets['tasmax_rcp45_2015_1_stats']['tasmax_mean'], - 'Dataset_vars': datasets['tasmax_rcp45_2015_2_stats'], - 'Dataset_dims': datasets['tasmax_rcp85_2015_1_perc']}, - line_kw={'DataArray': {'color': '#8a2be2'}, - 'Dataset_vars': {'color': '#ffa500'}, - 'Dataset_dims': {'color': '#468499'} - }, legend='edge') - -# test with non-ensemble DS -timeseries(rand_ds) - -#test different length arrays -timeseries({'random': rand_ds,'rcp45_2015_1': datasets['tasmax_rcp45_2015_1_perc']}) - -# From f06fd81e2edf90dabe04b5f6906a1de54ee0e25f Mon Sep 17 00:00:00 2001 From: abeaupre Date: Mon, 6 Feb 2023 12:21:48 -0500 Subject: [PATCH 18/25] split_legend func, docstrings updated, non-ensemble DS legend fixed --- spirograph/matplotlib/timeseries.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 45aa7bb4..8c6b7a90 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -8,6 +8,7 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines', show_coords = True): """ Plots time series from 1D dataframes or datasets + Parameters __________ data: dict or Dataset/DataArray From 12066fca46e9181b9b64a4dd2fa69be3306e42dc Mon Sep 17 00:00:00 2001 From: Beauprel <121904497+Beauprel@users.noreply.github.com> Date: Wed, 8 Feb 2023 17:01:33 -0500 Subject: [PATCH 19/25] Apply suggestions from code review Co-authored-by: juliettelavoie --- spirograph/matplotlib/timeseries.py | 12 ++++++------ spirograph/matplotlib/util_fcts.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 45aa7bb4..02b03e2b 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -11,20 +11,20 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= Parameters __________ data: dict or Dataset/DataArray - dictionary of labeled Xarray DataArrays or Datasets + Input data to plot. It can be a DataArrays, Datasets or a dictionary of DataArrays or Datasets. ax: matplotlib axis - user-specified matplotlib axis + Matplotlib axis on which to plot. use_attrs: dict dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description') sub_kw: dict - matplotlib subplots kwargs in the format {'param': value} + Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided. line_kw: dict - matplotlib or xarray line kwargs in the format {'param': value} + Arguments to pass the `plot()` function. This is used to change how the line looks. legend: str 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), 'edge' (out of plot), 'none' (no legend) show_coords: bool - show latitude, longitude coordinates at the bottom right of the figure + show latitude and longitude coordinates at the bottom right of the figure Returns _______ matplotlib axis @@ -42,7 +42,7 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= ## type for name, arr in data.items(): if not isinstance(arr, (xr.Dataset, xr.DataArray)): - raise TypeError('data must contain Xarray-type objects') + raise TypeError('`data` must contain a xr.Dataset, a xr.DataArray or a dictionary of xr.Dataset/ xr.DataArray.') ## 'time' dimension and calendar format data = check_timeindex(data) diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index c4a138d3..a0577c4c 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -3,6 +3,7 @@ import warnings import xarray as xr import matplotlib as mpl +import numpy as np def empty_dict(param): @@ -44,7 +45,7 @@ def get_array_categ(array): Returns _________ - str + array: str PCT_VAR_ENS: ensemble of percentiles stored as variables PCT_DIM_ENS_DA: ensemble of percentiles stored as dimension coordinates, DataArray STATS_VAR_ENS: ensemble of statistics (min, mean, max) stored as variables From 7aa591658192c38ba0060f0e5da7ee118d0eb4cd Mon Sep 17 00:00:00 2001 From: abeaupre Date: Thu, 9 Feb 2023 19:56:22 -0500 Subject: [PATCH 20/25] - adressed comments from PR-12 - Simplified how xlabel 'time' is assigned - Changed default-setting method for use_attrs - Added support for Dataset with multiple percentile ensembles --- spirograph/matplotlib/timeseries.py | 157 +++++++++++++++++++--------- spirograph/matplotlib/util_fcts.py | 19 ++-- 2 files changed, 115 insertions(+), 61 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 11b266fd..84db6145 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -1,13 +1,16 @@ import xarray as xr import matplotlib.pyplot as plt +from spirograph.matplotlib.util_fcts import empty_dict,check_timeindex, get_array_categ, \ + sort_lines, get_suffix,set_plot_attrs, plot_latlon, split_legend # To add # translation to fr # logo -def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend='lines', show_coords = True): +def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend='lines', show_coords = True): """ - Plots time series from 1D dataframes or datasets + Plots time series from 1D Xarray Datasets or DataArrays as line plots.Recognizes Xclim percentiles + or statistics ensembles and plots as shaded regions with a central line. Parameters __________ @@ -16,29 +19,24 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= ax: matplotlib axis Matplotlib axis on which to plot. use_attrs: dict - dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description') - sub_kw: dict + Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). + Default value is {'title': 'long_name', 'ylabel': 'standard_name', 'yunits': 'units'}. + Only the keys found in the default dict can be used + fig_kw: dict Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided. - line_kw: dict - Arguments to pass the `plot()` function. This is used to change how the line looks. + plot_kw: dict + Arguments to pass the `plot()` function. Changes how the line looks. + Must be a nested dictionary if data is a dictionary. legend: str 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), - 'edge' (out of plot), 'none' (no legend) + 'edge' (out of plot), 'none' (no legend). show_coords: bool - show latitude and longitude coordinates at the bottom right of the figure + Show latitude and longitude coordinates at the bottom right of the figure. Returns _______ matplotlib axis """ - # if only one data input, insert in dict - non_dict_data = False - - if type(data) != dict: - data = {'_no_label': data} # mpl excludes labels starting with "_" from legend - line_kw = {'_no_label': empty_dict(line_kw)} - non_dict_data = True - # basic checks ## type for name, arr in data.items(): @@ -48,31 +46,40 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= ## 'time' dimension and calendar format data = check_timeindex(data) - # set default kwargs - plot_attrs = {'title': 'long_name', - 'xlabel': 'time', - 'ylabel': 'standard_name', - 'yunits': 'units'} - plot_sub_kw = {} + #create empty dicts if None + use_attrs = empty_dict(use_attrs) + fig_kw = empty_dict(fig_kw) + plot_kw = empty_dict(plot_kw) - if non_dict_data is True: - plot_line_kw = {} + # if only one data input, insert in dict. + non_dict_data = False + if type(data) != dict: + non_dict_data = True + data = {'_no_label': data} # mpl excludes labels starting with "_" from legend + plot_kw = {'_no_label': empty_dict(plot_kw)} + + #assign keys to plot_kw if empty + if len(plot_kw) == 0: + for name, arr in data.items(): + plot_kw[name] = {} else: - plot_line_kw = {name: {} for name in data.keys()} + for name, arr in data.items(): + if name not in plot_kw: + raise Exception('plot_kw must be a nested dictionary with keys corresponding to the keys in "data"') - # add/replace default kwargs with user inputs - for user_dict, attr_dict in zip([use_attrs, sub_kw, line_kw], - [plot_attrs, plot_sub_kw, plot_line_kw]): - if user_dict: - for k, v in user_dict.items(): - attr_dict[k] = v - kwargs = {'sub_kw': plot_sub_kw, 'line_kw': plot_line_kw} + # set default use_attrs values + use_attrs.setdefault('title', 'long_name') + use_attrs.setdefault('ylabel', 'standard_name') + use_attrs.setdefault('yunits', 'units') + + kwargs = {'fig_kw': fig_kw, 'plot_kw': plot_kw} + print(kwargs) # set fig, ax if not provided if not ax: - fig, ax = plt.subplots(**kwargs['sub_kw']) + fig, ax = plt.subplots(**kwargs['fig_kw']) # build dictionary of array 'categories', which determine how to plot data array_categ = {name: get_array_categ(array) for name, array in data.items()} @@ -82,12 +89,56 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= for name, arr in data.items(): - # add name in line kwargs if not there, to avoid error due to double 'label' args in plot() - if 'label' not in kwargs['line_kw'][name]: - kwargs['line_kw'][name]['label'] = name + # add 'label':name in line kwargs if not there, to avoid error due to double 'label' args in plot() + if 'label' not in kwargs['plot_kw'][name]: + kwargs['plot_kw'][name]['label'] = name + + + # Dataset containing percentile ensembles + if array_categ[name] == 'PCT_DIM_ENS_DS': + for k, sub_arr in arr.data_vars.items(): + if non_dict_data is True: + sub_name = sub_arr.name + else: + sub_name = kwargs['plot_kw'][name]['label'] + "_" + sub_arr.name + print('subname:', sub_name) + + # extract each percentile array from the dims + array_data = {} + for pct in sub_arr.percentiles: + array_data[str(int(pct))] = sub_arr.sel(percentiles=int(pct)) + + # create a dictionary labeling the middle, upper and lower line + sorted_lines = sort_lines(array_data) + + # plot line while temporary changing label to sub_name + store_label = kwargs['plot_kw'][name]['label'] + kwargs['plot_kw'][name]['label'] = sub_name + + lines_dict[sub_name] = ax.plot(array_data[sorted_lines['middle']]['time'], + array_data[sorted_lines['middle']].values, + **kwargs['plot_kw'][name]) + + kwargs['plot_kw'][name]['label'] = store_label + + # plot shading + fill_between_label = "{}th-{}th percentiles".format(get_suffix(sorted_lines['lower']), + get_suffix(sorted_lines['upper'])) - # ensembles - if array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS_DA']: + if legend != 'full': + fill_between_label = None + + ax.fill_between(array_data[sorted_lines['lower']]['time'], + array_data[sorted_lines['lower']].values, + array_data[sorted_lines['upper']].values, + color=lines_dict[sub_name][0].get_color(), + linewidth=0.0, alpha=0.2, label=fill_between_label) + + + + + # other ensembles + elif array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS_DA']: # extract each array from the datasets array_data = {} @@ -104,7 +155,7 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= # plot line lines_dict[name] = ax.plot(array_data[sorted_lines['middle']]['time'], array_data[sorted_lines['middle']].values, - **kwargs['line_kw'][name]) + **kwargs['plot_kw'][name]) # plot shading if array_categ[name] in ['PCT_VAR_ENS', 'PCT_DIM_ENS_DA']: @@ -119,7 +170,8 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= array_data[sorted_lines['lower']].values, array_data[sorted_lines['upper']].values, color=lines_dict[name][0].get_color(), - linewidth = 0.0, alpha=0.2, label=fill_between_label) + linewidth=0.0, alpha=0.2, label=fill_between_label) + # non-ensemble Datasets elif array_categ[name] in ['DS']: @@ -127,34 +179,37 @@ def timeseries(data, ax=None, use_attrs=None, sub_kw=None, line_kw=None, legend= if non_dict_data is True: sub_name = sub_arr.name else: - sub_name = kwargs['line_kw'][name]['label'] + "_" + sub_arr.name + sub_name = kwargs['plot_kw'][name]['label'] + "_" + sub_arr.name - #put sub_name in line_kwargs to label correctly on plot, store the + #put sub_name in plot_kwargs to label correctly on plot, store the # original, and put it back after - store_label = kwargs['line_kw'][name]['label'] - kwargs['line_kw'][name]['label'] = sub_name - lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **kwargs['line_kw'][name]) - kwargs['line_kw'][name]['label'] = store_label + store_label = kwargs['plot_kw'][name]['label'] + kwargs['plot_kw'][name]['label'] = sub_name + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **kwargs['plot_kw'][name]) + kwargs['plot_kw'][name]['label'] = store_label # non-ensemble DataArrays elif array_categ[name] in ['DA']: - lines_dict[name] = ax.plot(arr['time'], arr.values, **kwargs['line_kw'][name]) + lines_dict[name] = ax.plot(arr['time'], arr.values, **kwargs['plot_kw'][name]) else: - raise Exception('Data structure not supported') + raise Exception('Data structure not supported') # can probably be removed along with elif logic above, + # given that get_array_categ() checks also + + # add/modify plot elements according to the first entry. - set_plot_attrs(plot_attrs, list(data.values())[0], ax) + set_plot_attrs(use_attrs, list(data.values())[0], ax) - # other plot elements (check overlap with Stylesheet!) + # other plot elements (will be replaced by Stylesheet) ax.margins(x=0, y=0.05) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) if show_coords: - plot_coords(ax, list(data.values())[0]) + plot_latlon(ax, list(data.values())[0]) if legend is not None: # non_dict_data is False and if legend == 'in_plot': diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index a0577c4c..7289c1fd 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -57,7 +57,7 @@ def get_array_categ(array): cat = "PCT_VAR_ENS" elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: cat = "STATS_VAR_ENS" - elif pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: + elif 'percentiles' in array.dims: cat = "PCT_DIM_ENS_DS" # placeholder, no support for now else: cat = "DS" @@ -100,9 +100,9 @@ def get_attributes(string, xr_obj): if string in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable return xr_obj[list(xr_obj.data_vars)[0]].attrs[string] - else: - warnings.warn('Attribute "{0}" not found in attributes'.format(string)) - return '' ## would it be better to return None? if so, need to fix ylabel in set_plot_attrs() + else: + warnings.warn('Attribute "{0}" not found in attributes'.format(string)) + return '' ## would it be better to return None? if so, need to fix ylabel in set_plot_attrs() def set_plot_attrs(attr_dict, xr_obj, ax): @@ -125,15 +125,14 @@ def set_plot_attrs(attr_dict, xr_obj, ax): """ # check for key in attr_dict: - if key not in ['title','xlabel', 'ylabel', 'yunits']: + if key not in ['title', 'ylabel', 'yunits']: warnings.warn('Use_attrs element "{}" not supported'.format(key)) + ax.set_xlabel('time') # check_timeindex() already checks for 'time' + if 'title' in attr_dict: ax.set_title(get_attributes(attr_dict['title'], xr_obj), wrap=True) - if 'xlabel' in attr_dict: - ax.set_xlabel(get_attributes(attr_dict['xlabel'], xr_obj)) - if 'ylabel' in attr_dict: if 'yunits' in attr_dict and len(get_attributes(attr_dict['yunits'], xr_obj)) >= 1: # second condition avoids '[]' as label ax.set_ylabel(get_attributes(attr_dict['ylabel'], xr_obj) + ' (' + @@ -192,14 +191,14 @@ def sort_lines(array_dict): return sorted_lines -def plot_coords(ax, xr_obj): +def plot_latlon(ax, xr_obj): """ place lat, lon coordinates on bottom right of plot area""" if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: text = 'lat={:.2f}, lon={:.2f}'.format(float(xr_obj['lat']), float(xr_obj['lon'])) ax.text(0.99, 0.01, text, transform=ax.transAxes, ha = 'right', va = 'bottom') else: - warnings.warn('show_coords set to True, but no coordinates found in {}.coords'.format(xr_obj)) + warnings.warn('show_coords set to True, but "lat" and/or "lon" not found in {}.coords'.format(xr_obj)) return ax From abb2f2c0851196122c29bdb63d79f347d10d3bdd Mon Sep 17 00:00:00 2001 From: abeaupre Date: Fri, 10 Feb 2023 11:58:13 -0500 Subject: [PATCH 21/25] - fixed label bugs - minor name updates --- spirograph/matplotlib/timeseries.py | 34 ++++++++++++++--------------- spirograph/matplotlib/util_fcts.py | 22 ++++++++----------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index 84db6145..f8c40b6f 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -1,13 +1,11 @@ import xarray as xr import matplotlib.pyplot as plt -from spirograph.matplotlib.util_fcts import empty_dict,check_timeindex, get_array_categ, \ - sort_lines, get_suffix,set_plot_attrs, plot_latlon, split_legend +from spirograph.matplotlib.util_fcts import empty_dict, check_timeindex, get_array_categ, \ + sort_lines, get_suffix, set_plot_attrs, split_legend, plot_lat_lon -# To add -# translation to fr -# logo +# Todo: translation to fr, logo -def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend='lines', show_coords = True): +def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend='lines', show_lat_lon = True): """ Plots time series from 1D Xarray Datasets or DataArrays as line plots.Recognizes Xclim percentiles or statistics ensembles and plots as shaded regions with a central line. @@ -30,22 +28,13 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= legend: str 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), 'edge' (out of plot), 'none' (no legend). - show_coords: bool + show_lat_lon: bool Show latitude and longitude coordinates at the bottom right of the figure. Returns _______ matplotlib axis """ - # basic checks - ## type - for name, arr in data.items(): - if not isinstance(arr, (xr.Dataset, xr.DataArray)): - raise TypeError('`data` must contain a xr.Dataset, a xr.DataArray or a dictionary of xr.Dataset/ xr.DataArray.') - - ## 'time' dimension and calendar format - data = check_timeindex(data) - #create empty dicts if None use_attrs = empty_dict(use_attrs) @@ -68,6 +57,15 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= if name not in plot_kw: raise Exception('plot_kw must be a nested dictionary with keys corresponding to the keys in "data"') + # basic checks + ## type + for name, arr in data.items(): + if not isinstance(arr, (xr.Dataset, xr.DataArray)): + raise TypeError('`data` must contain a xr.Dataset, a xr.DataArray or a dictionary of xr.Dataset/ xr.DataArray.') + + ## 'time' dimension and calendar format + data = check_timeindex(data) + # set default use_attrs values use_attrs.setdefault('title', 'long_name') @@ -208,8 +206,8 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) - if show_coords: - plot_latlon(ax, list(data.values())[0]) + if show_lat_lon: + plot_lat_lon(ax, list(data.values())[0]) if legend is not None: # non_dict_data is False and if legend == 'in_plot': diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index 7289c1fd..040d9569 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -93,12 +93,8 @@ def get_attributes(string, xr_obj): if string in xr_obj.attrs: return xr_obj.attrs[string] - elif string in xr_obj.dims: - return string # special case for 'time' because DataArray and Dataset dims are not the same types - - elif isinstance(xr_obj, xr.Dataset): - if string in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable - return xr_obj[list(xr_obj.data_vars)[0]].attrs[string] + elif isinstance(xr_obj, xr.Dataset) and string in xr_obj[list(xr_obj.data_vars)[0]].attrs: # DataArray of first variable + return xr_obj[list(xr_obj.data_vars)[0]].attrs[string] else: warnings.warn('Attribute "{0}" not found in attributes'.format(string)) @@ -191,19 +187,19 @@ def sort_lines(array_dict): return sorted_lines -def plot_latlon(ax, xr_obj): +def plot_lat_lon(ax, xr_obj): """ place lat, lon coordinates on bottom right of plot area""" if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: text = 'lat={:.2f}, lon={:.2f}'.format(float(xr_obj['lat']), - float(xr_obj['lon'])) + float(xr_obj['lon'])) ax.text(0.99, 0.01, text, transform=ax.transAxes, ha = 'right', va = 'bottom') else: - warnings.warn('show_coords set to True, but "lat" and/or "lon" not found in {}.coords'.format(xr_obj)) - + warnings.warn('show_latlon set to True, but "lat" and/or "lon" not found in {}.coords'.format(xr_obj)) return ax -def split_legend(ax, out = True, axis_factor=0.15, label_gap=0.02): +def split_legend(ax, in_plot = False, axis_factor=0.15, label_gap=0.02): + # TODO: check for and fix overlapping labels """ Draws line labels at the end of each line, or outside the plot @@ -229,7 +225,7 @@ def split_legend(ax, out = True, axis_factor=0.15, label_gap=0.02): ax_bump = (init_xbound[1] - init_xbound[0]) * axis_factor label_bump = (init_xbound[1] - init_xbound[0]) * label_gap - if out is False: + if in_plot is True: ax.set_xbound(lower=init_xbound[0], upper=init_xbound[1] + ax_bump) #get legend and plot @@ -246,7 +242,7 @@ def split_legend(ax, out = True, axis_factor=0.15, label_gap=0.02): color = handle.get_color() ls = handle.get_linestyle() - if out is False: + if in_plot is True: ax.text(last_x + label_bump, last_y, label, ha='left', va='center', color=color) else: From dae50986195581e7b046d3d074462c03a3462061 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Wed, 15 Feb 2023 13:58:36 -0500 Subject: [PATCH 22/25] - labelling system for non-ensemble Datasets - minor bugs fixed --- spirograph/matplotlib/timeseries.py | 74 +++++++++++++++-------------- spirograph/matplotlib/util_fcts.py | 64 +++++++++++++------------ 2 files changed, 72 insertions(+), 66 deletions(-) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/timeseries.py index f8c40b6f..6447ed6b 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/timeseries.py @@ -7,35 +7,33 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend='lines', show_lat_lon = True): """ - Plots time series from 1D Xarray Datasets or DataArrays as line plots.Recognizes Xclim percentiles - or statistics ensembles and plots as shaded regions with a central line. + Plot time series from 1D Xarray Datasets or DataArrays as line plots. Parameters __________ data: dict or Dataset/DataArray - Input data to plot. It can be a DataArrays, Datasets or a dictionary of DataArrays or Datasets. + Input data to plot. It can be a DataArray, Dataset or a dictionary of DataArrays and/or Datasets. ax: matplotlib axis Matplotlib axis on which to plot. use_attrs: dict Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'title': 'long_name', 'ylabel': 'standard_name', 'yunits': 'units'}. - Only the keys found in the default dict can be used + Only the keys found in the default dict can be used. fig_kw: dict Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided. plot_kw: dict Arguments to pass the `plot()` function. Changes how the line looks. - Must be a nested dictionary if data is a dictionary. - legend: str + If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'. + legend: str (default 'lines') 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), 'edge' (out of plot), 'none' (no legend). - show_lat_lon: bool - Show latitude and longitude coordinates at the bottom right of the figure. + show_lat_lon: bool (default True) + If True, show latitude and longitude coordinates at the bottom right of the figure. Returns _______ matplotlib axis """ - #create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) @@ -72,12 +70,11 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= use_attrs.setdefault('ylabel', 'standard_name') use_attrs.setdefault('yunits', 'units') - kwargs = {'fig_kw': fig_kw, 'plot_kw': plot_kw} - print(kwargs) + # set fig, ax if not provided if not ax: - fig, ax = plt.subplots(**kwargs['fig_kw']) + fig, ax = plt.subplots(**fig_kw) # build dictionary of array 'categories', which determine how to plot data array_categ = {name: get_array_categ(array) for name, array in data.items()} @@ -87,9 +84,8 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= for name, arr in data.items(): - # add 'label':name in line kwargs if not there, to avoid error due to double 'label' args in plot() - if 'label' not in kwargs['plot_kw'][name]: - kwargs['plot_kw'][name]['label'] = name + # add 'label':name in ax.plot() kwargs if not there, to avoid error due to double 'label' args + plot_kw[name].setdefault('label', name) # Dataset containing percentile ensembles @@ -98,8 +94,7 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= if non_dict_data is True: sub_name = sub_arr.name else: - sub_name = kwargs['plot_kw'][name]['label'] + "_" + sub_arr.name - print('subname:', sub_name) + sub_name = plot_kw[name]['label'] + "_" + sub_arr.name # extract each percentile array from the dims array_data = {} @@ -110,14 +105,14 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= sorted_lines = sort_lines(array_data) # plot line while temporary changing label to sub_name - store_label = kwargs['plot_kw'][name]['label'] - kwargs['plot_kw'][name]['label'] = sub_name + store_label = plot_kw[name]['label'] + plot_kw[name]['label'] = sub_name lines_dict[sub_name] = ax.plot(array_data[sorted_lines['middle']]['time'], array_data[sorted_lines['middle']].values, - **kwargs['plot_kw'][name]) + **plot_kw[name]) - kwargs['plot_kw'][name]['label'] = store_label + plot_kw[name]['label'] = store_label # plot shading fill_between_label = "{}th-{}th percentiles".format(get_suffix(sorted_lines['lower']), @@ -153,7 +148,7 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= # plot line lines_dict[name] = ax.plot(array_data[sorted_lines['middle']]['time'], array_data[sorted_lines['middle']].values, - **kwargs['plot_kw'][name]) + **plot_kw[name]) # plot shading if array_categ[name] in ['PCT_VAR_ENS', 'PCT_DIM_ENS_DA']: @@ -177,23 +172,28 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= if non_dict_data is True: sub_name = sub_arr.name else: - sub_name = kwargs['plot_kw'][name]['label'] + "_" + sub_arr.name + sub_name = plot_kw[name]['label'] + "_" + sub_arr.name - #put sub_name in plot_kwargs to label correctly on plot, store the - # original, and put it back after - store_label = kwargs['plot_kw'][name]['label'] - kwargs['plot_kw'][name]['label'] = sub_name - lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **kwargs['plot_kw'][name]) - kwargs['plot_kw'][name]['label'] = store_label + #label will be modified, so store now and put back later + store_label = plot_kw[name]['label'] + + # if kwargs are specified by user, all lines are the same and we want one legend entry + if len(plot_kw[name]) >= 2: # 'label' is there by default + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **plot_kw[name]) + plot_kw[name]['label'] = '' # makes sure label only appears once + else: + plot_kw[name]['label'] = sub_name + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **plot_kw[name]) + plot_kw[name]['label'] = store_label # non-ensemble DataArrays elif array_categ[name] in ['DA']: - lines_dict[name] = ax.plot(arr['time'], arr.values, **kwargs['plot_kw'][name]) + lines_dict[name] = ax.plot(arr['time'], arr.values, **plot_kw[name]) else: - raise Exception('Data structure not supported') # can probably be removed along with elif logic above, - # given that get_array_categ() checks also + raise Exception('Data structure not supported') # can probably be removed along with elif logic above, + # given that get_array_categ() checks also @@ -209,11 +209,13 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= if show_lat_lon: plot_lat_lon(ax, list(data.values())[0]) - if legend is not None: # non_dict_data is False and - if legend == 'in_plot': - split_legend(ax, out=False) + if legend is not None: + if not ax.get_legend_handles_labels()[0]: # check if legend is empty + pass + elif legend == 'in_plot': + split_legend(ax, in_plot=True) elif legend == 'edge': - split_legend(ax, out=True) + split_legend(ax, in_plot=False) else: ax.legend() diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/util_fcts.py index 040d9569..bebd42c3 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/util_fcts.py @@ -5,22 +5,24 @@ import matplotlib as mpl import numpy as np +warnings.simplefilter('always', UserWarning) + def empty_dict(param): - """ returns empty dict if input is None""" + """ Return empty dict if input is None. """ if param is None: param = {} return param def check_timeindex(xr_dict): - """ checks if the time index of Xarray objects in a dict is CFtime - and converts to pd.DatetimeIndex if true + """ Check if the time index of Xarray objects in a dict is CFtime + and convert to pd.DatetimeIndex if True. Parameters _________ xr_dict: dict - dictionary containing Xarray DataArrays or Datasets + Dictionary containing Xarray DataArrays or Datasets. Returns _______ dict @@ -37,11 +39,12 @@ def check_timeindex(xr_dict): def get_array_categ(array): - """Returns an array category, which determines how to plot + """Return an array category, which determines how to plot. Parameters __________ array: Dataset or DataArray + The array being categorized. Returns _________ @@ -75,15 +78,15 @@ def get_array_categ(array): def get_attributes(string, xr_obj): """ - Fetches attributes or dims corresponding to keys from Xarray objects. Looks in - Dataset attributes first, then looks in DataArray. + Fetch attributes or dims corresponding to keys from Xarray objects. Look in + Dataset attributes first, then in the first variable (DataArray) of the Dataset. Parameters _________ string: str - string corresponding to an attribute name + String corresponding to an attribute name. xr_obj: DataArray or Dataset - the Xarray object containing the attributes + The Xarray object containing the attributes. Returns _______ @@ -97,23 +100,23 @@ def get_attributes(string, xr_obj): return xr_obj[list(xr_obj.data_vars)[0]].attrs[string] else: - warnings.warn('Attribute "{0}" not found in attributes'.format(string)) - return '' ## would it be better to return None? if so, need to fix ylabel in set_plot_attrs() + warnings.warn('Attribute "{}" not found.'.format(string)) + return '' def set_plot_attrs(attr_dict, xr_obj, ax): """ - Sets plot elements according to Dataset or DataArray attributes. Uses get_attributes() + Set plot elements according to Dataset or DataArray attributes. Uses get_attributes() to check for and return the string. Parameters __________ - use_attrs: dict - dictionary containing specified attribute keys + attr_dict: dict + Dictionary containing specified attribute keys. xr_obj: Dataset or DataArray - The Xarray object containing the attributes + The Xarray object containing the attributes. ax: matplotlib axis - the matplotlib axis + The matplotlib axis of the plot. Returns ______ matplotlib axis @@ -139,7 +142,7 @@ def set_plot_attrs(attr_dict, xr_obj, ax): def get_suffix(string): - """ get suffix of typical Xclim variable names""" + """ Get suffix of typical Xclim variable names. """ if re.search("[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string): suffix = re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", string).group() return suffix @@ -149,16 +152,17 @@ def get_suffix(string): def sort_lines(array_dict): """ - Labels arrays as 'middle', 'upper' and 'lower' for ensemble plotting + Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting. Parameters _______ - array_dict: dict of {'name': array}. + array_dict: dict + Dictionary of format {'name': array...}. Returns _______ dict - dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'} + Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}. """ if len(array_dict) != 3: raise ValueError('Ensembles must contain exactly three arrays') @@ -188,31 +192,31 @@ def sort_lines(array_dict): def plot_lat_lon(ax, xr_obj): - """ place lat, lon coordinates on bottom right of plot area""" + """ Place lat, lon coordinates on bottom right of plot area.""" if 'lat' in xr_obj.coords and 'lon' in xr_obj.coords: text = 'lat={:.2f}, lon={:.2f}'.format(float(xr_obj['lat']), float(xr_obj['lon'])) ax.text(0.99, 0.01, text, transform=ax.transAxes, ha = 'right', va = 'bottom') else: - warnings.warn('show_latlon set to True, but "lat" and/or "lon" not found in {}.coords'.format(xr_obj)) + warnings.warn('show_lat_lon set to True, but "lat" and/or "lon" not found in coords') return ax def split_legend(ax, in_plot = False, axis_factor=0.15, label_gap=0.02): # TODO: check for and fix overlapping labels """ - Draws line labels at the end of each line, or outside the plot + Drawline labels at the end of each line, or outside the plot. Parameters _______ ax: matplotlib axis - the axis containing the legend - out: bool (default True) - if True, print the labels outside of plot area. if False, prolongs plot area to fit labels - axis_factor: float - if out is True, percentage of the x-axis length to add at the far right of the plot - label_gap: float - if out is True,percentage of the x-axis length to add as a gap between line and label + The axis containing the legend. + in_plot: bool (default False) + If True, prolong plot area to fit labels. If False, print labels outside of plot area. + axis_factor: float (default 0.15) + If in_plot is True, fraction of the x-axis length to add at the far right of the plot. + label_gap: float (default 0.02) + If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Returns ______ From 8adde5734e642ae12e3abf4936355a886487adf2 Mon Sep 17 00:00:00 2001 From: abeaupre Date: Thu, 16 Feb 2023 17:52:08 -0500 Subject: [PATCH 23/25] - renamed files - added support for climate ensembles - added utils plot_realizations() and fill_between_label() - etc. --- .../matplotlib/{timeseries.py => plot.py} | 117 ++++++++---------- .../matplotlib/{util_fcts.py => utils.py} | 74 +++++++++-- 2 files changed, 115 insertions(+), 76 deletions(-) rename spirograph/matplotlib/{timeseries.py => plot.py} (63%) rename spirograph/matplotlib/{util_fcts.py => utils.py} (76%) diff --git a/spirograph/matplotlib/timeseries.py b/spirograph/matplotlib/plot.py similarity index 63% rename from spirograph/matplotlib/timeseries.py rename to spirograph/matplotlib/plot.py index 6447ed6b..e55b16b4 100644 --- a/spirograph/matplotlib/timeseries.py +++ b/spirograph/matplotlib/plot.py @@ -1,7 +1,8 @@ import xarray as xr import matplotlib.pyplot as plt -from spirograph.matplotlib.util_fcts import empty_dict, check_timeindex, get_array_categ, \ - sort_lines, get_suffix, set_plot_attrs, split_legend, plot_lat_lon +import warnings +from spirograph.matplotlib.utils import empty_dict, check_timeindex, get_array_categ, \ + sort_lines, set_plot_attrs, split_legend, plot_lat_lon, plot_realizations, fill_between_label # Todo: translation to fr, logo @@ -34,7 +35,7 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= matplotlib axis """ - #create empty dicts if None + # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) @@ -46,7 +47,7 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= data = {'_no_label': data} # mpl excludes labels starting with "_" from legend plot_kw = {'_no_label': empty_dict(plot_kw)} - #assign keys to plot_kw if empty + # assign keys to plot_kw if empty if len(plot_kw) == 0: for name, arr in data.items(): plot_kw[name] = {} @@ -55,87 +56,78 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= if name not in plot_kw: raise Exception('plot_kw must be a nested dictionary with keys corresponding to the keys in "data"') - # basic checks - ## type + # check: type for name, arr in data.items(): if not isinstance(arr, (xr.Dataset, xr.DataArray)): raise TypeError('`data` must contain a xr.Dataset, a xr.DataArray or a dictionary of xr.Dataset/ xr.DataArray.') - ## 'time' dimension and calendar format + # check: 'time' dimension and calendar format data = check_timeindex(data) - # set default use_attrs values use_attrs.setdefault('title', 'long_name') use_attrs.setdefault('ylabel', 'standard_name') use_attrs.setdefault('yunits', 'units') - - # set fig, ax if not provided if not ax: fig, ax = plt.subplots(**fig_kw) - # build dictionary of array 'categories', which determine how to plot data + # dict of array 'categories' array_categ = {name: get_array_categ(array) for name, array in data.items()} - # get data and plot lines_dict = {} # created to facilitate accessing line properties later + # get data and plot for name, arr in data.items(): - # add 'label':name in ax.plot() kwargs if not there, to avoid error due to double 'label' args - plot_kw[name].setdefault('label', name) + # remove 'label' to avoid error due to double 'label' args + if 'label' in plot_kw[name]: + del plot_kw[name]['label'] + warnings.warn('"label" entry in plot_kw[{}] will be ignored.'.format(name)) - # Dataset containing percentile ensembles - if array_categ[name] == 'PCT_DIM_ENS_DS': + if array_categ[name] == "ENS_REALS_DA": + plot_realizations(ax, arr, name, plot_kw, non_dict_data) + + elif array_categ[name] == "ENS_REALS_DS": + if len(arr.data_vars) >= 2: + raise Exception('To plot multiple ensembles containing realizations, use DataArrays outside a Dataset') for k, sub_arr in arr.data_vars.items(): - if non_dict_data is True: - sub_name = sub_arr.name - else: - sub_name = plot_kw[name]['label'] + "_" + sub_arr.name + plot_realizations(ax, sub_arr, name, plot_kw, non_dict_data) + + elif array_categ[name] == 'ENS_PCT_DIM_DS': + for k, sub_arr in arr.data_vars.items(): + + sub_name = sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name) # extract each percentile array from the dims array_data = {} - for pct in sub_arr.percentiles: - array_data[str(int(pct))] = sub_arr.sel(percentiles=int(pct)) + for pct in sub_arr.percentiles.values: + array_data[str(pct)] = sub_arr.sel(percentiles=pct) # create a dictionary labeling the middle, upper and lower line sorted_lines = sort_lines(array_data) - # plot line while temporary changing label to sub_name - store_label = plot_kw[name]['label'] - plot_kw[name]['label'] = sub_name - + # plot lines_dict[sub_name] = ax.plot(array_data[sorted_lines['middle']]['time'], - array_data[sorted_lines['middle']].values, - **plot_kw[name]) - - plot_kw[name]['label'] = store_label - - # plot shading - fill_between_label = "{}th-{}th percentiles".format(get_suffix(sorted_lines['lower']), - get_suffix(sorted_lines['upper'])) - - if legend != 'full': - fill_between_label = None + array_data[sorted_lines['middle']].values, + label=sub_name, **plot_kw[name]) ax.fill_between(array_data[sorted_lines['lower']]['time'], array_data[sorted_lines['lower']].values, array_data[sorted_lines['upper']].values, color=lines_dict[sub_name][0].get_color(), - linewidth=0.0, alpha=0.2, label=fill_between_label) - - + linewidth=0.0, alpha=0.2, + label=fill_between_label(sorted_lines, name, array_categ, legend)) # other ensembles - elif array_categ[name] in ['PCT_VAR_ENS', 'STATS_VAR_ENS', 'PCT_DIM_ENS_DA']: + elif array_categ[name] in ['ENS_PCT_VAR_DS', 'ENS_STATS_VAR_DS', 'ENS_PCT_DIM_DA']: # extract each array from the datasets array_data = {} - if array_categ[name] == 'PCT_DIM_ENS_DA': + if array_categ[name] == 'ENS_PCT_DIM_DA': for pct in arr.percentiles: array_data[str(int(pct))] = arr.sel(percentiles=int(pct)) else: @@ -145,51 +137,40 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= # create a dictionary labeling the middle, upper and lower line sorted_lines = sort_lines(array_data) - # plot line + # plot lines_dict[name] = ax.plot(array_data[sorted_lines['middle']]['time'], array_data[sorted_lines['middle']].values, - **plot_kw[name]) - - # plot shading - if array_categ[name] in ['PCT_VAR_ENS', 'PCT_DIM_ENS_DA']: - fill_between_label = "{}th-{}th percentiles".format(get_suffix(sorted_lines['lower']), - get_suffix(sorted_lines['upper'])) - if array_categ[name] in ['STATS_VAR_ENS']: - fill_between_label = "min-max range" - if legend != 'full': - fill_between_label = None + label=name, **plot_kw[name]) ax.fill_between(array_data[sorted_lines['lower']]['time'], array_data[sorted_lines['lower']].values, array_data[sorted_lines['upper']].values, color=lines_dict[name][0].get_color(), - linewidth=0.0, alpha=0.2, label=fill_between_label) + linewidth=0.0, alpha=0.2, + label=fill_between_label(sorted_lines, name, array_categ, legend)) # non-ensemble Datasets - elif array_categ[name] in ['DS']: + elif array_categ[name] == "DS": + + ignore_label = False for k, sub_arr in arr.data_vars.items(): - if non_dict_data is True: - sub_name = sub_arr.name - else: - sub_name = plot_kw[name]['label'] + "_" + sub_arr.name - #label will be modified, so store now and put back later - store_label = plot_kw[name]['label'] + sub_name = sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name) # if kwargs are specified by user, all lines are the same and we want one legend entry - if len(plot_kw[name]) >= 2: # 'label' is there by default - lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **plot_kw[name]) - plot_kw[name]['label'] = '' # makes sure label only appears once + if plot_kw[name]: + label = name if not ignore_label else '' + ignore_label = True else: - plot_kw[name]['label'] = sub_name - lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, **plot_kw[name]) - plot_kw[name]['label'] = store_label + label = sub_name + + lines_dict[sub_name] = ax.plot(sub_arr['time'], sub_arr.values, label=label, **plot_kw[name]) # non-ensemble DataArrays elif array_categ[name] in ['DA']: - lines_dict[name] = ax.plot(arr['time'], arr.values, **plot_kw[name]) + lines_dict[name] = ax.plot(arr['time'], arr.values, label=name, **plot_kw[name]) else: raise Exception('Data structure not supported') # can probably be removed along with elif logic above, diff --git a/spirograph/matplotlib/util_fcts.py b/spirograph/matplotlib/utils.py similarity index 76% rename from spirograph/matplotlib/util_fcts.py rename to spirograph/matplotlib/utils.py index bebd42c3..37caf318 100644 --- a/spirograph/matplotlib/util_fcts.py +++ b/spirograph/matplotlib/utils.py @@ -49,25 +49,32 @@ def get_array_categ(array): Returns _________ array: str - PCT_VAR_ENS: ensemble of percentiles stored as variables - PCT_DIM_ENS_DA: ensemble of percentiles stored as dimension coordinates, DataArray - STATS_VAR_ENS: ensemble of statistics (min, mean, max) stored as variables + ENS_PCT_VAR_DS: ensemble percentiles stored as variables + ENS_PCT_DIM_DA: ensemble percentiles stored as dimension coordinates, DataArray + ENS_PCT_DIM_DS: ensemble percentiles stored as dimension coordinates, DataSet + ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables + ENS_REALS_DA: ensemble with 'realization' dim, as DataArray + ENS_REALS_DS: ensemble with 'realization' dim, as Dataset DS: any Dataset that is not recognized as an ensemble DA: DataArray """ if isinstance(array, xr.Dataset): - if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >=2: - cat = "PCT_VAR_ENS" + if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >= 2: + cat = "ENS_PCT_VAR_DS" elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: - cat = "STATS_VAR_ENS" + cat = "ENS_STATS_VAR_DS" elif 'percentiles' in array.dims: - cat = "PCT_DIM_ENS_DS" # placeholder, no support for now + cat = "ENS_PCT_DIM_DS" + elif 'realization' in array.dims: + cat = "ENS_REALS_DS" else: cat = "DS" elif isinstance(array, xr.DataArray): if pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: - cat = "PCT_DIM_ENS_DA" + cat = "ENS_PCT_DIM_DA" + elif 'realization' in array.dims: + cat = "ENS_REALS_DA" else: cat = "DA" else: @@ -254,3 +261,54 @@ def split_legend(ax, in_plot = False, axis_factor=0.15, label_gap=0.02): ax.text(1.01, last_y, label, ha='left', va='center', color=color, transform=trans) return ax + + +def plot_realizations(ax, da, name, plot_kw, non_dict_data): + """ Plot realizations from a DataArray, inside or outside a Dataset + + Parameters + _________ + da: DataArray + The DataArray containing the realizations + name: str + The label to be used in the first part of a composite label. + Can be the name of the parent Dataset or that of the DataArray + plot_kw: dict + Dictionary of kwargs coming from the timeseries() input + ax: matplotlib axis + The Matplotlib axis + + Returns + _______ + Matplotlib axis + """ + ignore_label = False + + for r in da.realization.values: + + if plot_kw[name]: # if kwargs (all lines identical) + if not ignore_label: # if label not already in legend + label = '' if non_dict_data is True else name + ignore_label = True + else: + label = '' + else: + label = str(r) if non_dict_data is True else (name + '_' + str(r)) + + ax.plot(da.sel(realization=r)['time'], da.sel(realization=r).values, + label=label, **plot_kw[name]) + + return ax + + +def fill_between_label(sorted_lines, name, array_categ, legend): + """ Create label for shading""" + if legend != 'full': + label = None + elif array_categ[name] in ['ENS_PCT_VAR_DS','ENS_PCT_DIM_DS','ENS_PCT_DIM_DA']: + label = "{}th-{}th percentiles".format(get_suffix(sorted_lines['lower']), + get_suffix(sorted_lines['upper'])) + elif array_categ[name] == 'ENS_STATS_VAR_DS': + label = 'min-max range' + + return label From 7a0f1022b59195d43a0c9627a9f8602c6fe64a0a Mon Sep 17 00:00:00 2001 From: abeaupre Date: Fri, 17 Feb 2023 14:30:23 -0500 Subject: [PATCH 24/25] - changed import method - moved function from utils to plot --- spirograph/matplotlib/plot.py | 47 ++++++++++++++++++++++++++++++---- spirograph/matplotlib/utils.py | 47 ++++------------------------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/spirograph/matplotlib/plot.py b/spirograph/matplotlib/plot.py index e55b16b4..858b8db6 100644 --- a/spirograph/matplotlib/plot.py +++ b/spirograph/matplotlib/plot.py @@ -1,10 +1,47 @@ import xarray as xr import matplotlib.pyplot as plt import warnings -from spirograph.matplotlib.utils import empty_dict, check_timeindex, get_array_categ, \ - sort_lines, set_plot_attrs, split_legend, plot_lat_lon, plot_realizations, fill_between_label +from spirograph.matplotlib.utils import * + + +def _plot_realizations(ax, da, name, plot_kw, non_dict_data): + """ Plot realizations from a DataArray, inside or outside a Dataset. + + Parameters + _________ + da: DataArray + The DataArray containing the realizations + name: str + The label to be used in the first part of a composite label. + Can be the name of the parent Dataset or that of the DataArray + plot_kw: dict + Dictionary of kwargs coming from the timeseries() input + ax: matplotlib axis + The Matplotlib axis + + Returns + _______ + Matplotlib axis + """ + ignore_label = False + + for r in da.realization.values: + + if plot_kw[name]: # if kwargs (all lines identical) + if not ignore_label: # if label not already in legend + label = '' if non_dict_data is True else name + ignore_label = True + else: + label = '' + else: + label = str(r) if non_dict_data is True else (name + '_' + str(r)) + + ax.plot(da.sel(realization=r)['time'], da.sel(realization=r).values, + label=label, **plot_kw[name]) + + return ax + -# Todo: translation to fr, logo def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend='lines', show_lat_lon = True): """ @@ -88,13 +125,13 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= if array_categ[name] == "ENS_REALS_DA": - plot_realizations(ax, arr, name, plot_kw, non_dict_data) + _plot_realizations(ax, arr, name, plot_kw, non_dict_data) elif array_categ[name] == "ENS_REALS_DS": if len(arr.data_vars) >= 2: raise Exception('To plot multiple ensembles containing realizations, use DataArrays outside a Dataset') for k, sub_arr in arr.data_vars.items(): - plot_realizations(ax, sub_arr, name, plot_kw, non_dict_data) + _plot_realizations(ax, sub_arr, name, plot_kw, non_dict_data) elif array_categ[name] == 'ENS_PCT_DIM_DS': for k, sub_arr in arr.data_vars.items(): diff --git a/spirograph/matplotlib/utils.py b/spirograph/matplotlib/utils.py index 37caf318..4731859c 100644 --- a/spirograph/matplotlib/utils.py +++ b/spirograph/matplotlib/utils.py @@ -61,7 +61,7 @@ def get_array_categ(array): if isinstance(array, xr.Dataset): if pd.notnull([re.search("_p[0-9]{1,2}", var) for var in array.data_vars]).sum() >= 2: cat = "ENS_PCT_VAR_DS" - elif pd.notnull([re.search("[Mm]ax|[Mm]in", var) for var in array.data_vars]).sum() >= 2: + elif pd.notnull([re.search("_[Mm]ax|_[Mm]in", var) for var in array.data_vars]).sum() >= 2: cat = "ENS_STATS_VAR_DS" elif 'percentiles' in array.dims: cat = "ENS_PCT_DIM_DS" @@ -71,7 +71,7 @@ def get_array_categ(array): cat = "DS" elif isinstance(array, xr.DataArray): - if pd.notnull([re.search("percentiles", dim) for dim in array.dims]).sum() == 1: + if 'percentiles' in array.dims: cat = "ENS_PCT_DIM_DA" elif 'realization' in array.dims: cat = "ENS_REALS_DA" @@ -262,47 +262,8 @@ def split_legend(ax, in_plot = False, axis_factor=0.15, label_gap=0.02): return ax - -def plot_realizations(ax, da, name, plot_kw, non_dict_data): - """ Plot realizations from a DataArray, inside or outside a Dataset - - Parameters - _________ - da: DataArray - The DataArray containing the realizations - name: str - The label to be used in the first part of a composite label. - Can be the name of the parent Dataset or that of the DataArray - plot_kw: dict - Dictionary of kwargs coming from the timeseries() input - ax: matplotlib axis - The Matplotlib axis - - Returns - _______ - Matplotlib axis - """ - ignore_label = False - - for r in da.realization.values: - - if plot_kw[name]: # if kwargs (all lines identical) - if not ignore_label: # if label not already in legend - label = '' if non_dict_data is True else name - ignore_label = True - else: - label = '' - else: - label = str(r) if non_dict_data is True else (name + '_' + str(r)) - - ax.plot(da.sel(realization=r)['time'], da.sel(realization=r).values, - label=label, **plot_kw[name]) - - return ax - - def fill_between_label(sorted_lines, name, array_categ, legend): - """ Create label for shading""" + """ Create label for shading in line plots.""" if legend != 'full': label = None elif array_categ[name] in ['ENS_PCT_VAR_DS','ENS_PCT_DIM_DS','ENS_PCT_DIM_DA']: @@ -310,5 +271,7 @@ def fill_between_label(sorted_lines, name, array_categ, legend): get_suffix(sorted_lines['upper'])) elif array_categ[name] == 'ENS_STATS_VAR_DS': label = 'min-max range' + else: + label = None return label From d646e204b0632b9c8d58dc926febb5228c40b2e7 Mon Sep 17 00:00:00 2001 From: Beauprel <121904497+Beauprel@users.noreply.github.com> Date: Mon, 20 Feb 2023 16:57:40 -0500 Subject: [PATCH 25/25] Apply suggestions from code review Co-authored-by: juliettelavoie --- spirograph/matplotlib/plot.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spirograph/matplotlib/plot.py b/spirograph/matplotlib/plot.py index 858b8db6..a46283af 100644 --- a/spirograph/matplotlib/plot.py +++ b/spirograph/matplotlib/plot.py @@ -10,14 +10,14 @@ def _plot_realizations(ax, da, name, plot_kw, non_dict_data): Parameters _________ da: DataArray - The DataArray containing the realizations + The DataArray containing the realizations. name: str The label to be used in the first part of a composite label. - Can be the name of the parent Dataset or that of the DataArray + Can be the name of the parent Dataset or that of the DataArray. plot_kw: dict - Dictionary of kwargs coming from the timeseries() input + Dictionary of kwargs coming from the timeseries() input. ax: matplotlib axis - The Matplotlib axis + The Matplotlib axis. Returns _______ @@ -102,8 +102,8 @@ def timeseries(data, ax=None, use_attrs=None, fig_kw=None, plot_kw=None, legend= data = check_timeindex(data) # set default use_attrs values - use_attrs.setdefault('title', 'long_name') - use_attrs.setdefault('ylabel', 'standard_name') + use_attrs.setdefault('title', 'description') + use_attrs.setdefault('ylabel', 'long_name') use_attrs.setdefault('yunits', 'units') # set fig, ax if not provided