Skip to content

Commit

Permalink
Merge pull request #48 from DamienIrving/master
Browse files Browse the repository at this point in the history
Bug fixes
  • Loading branch information
DamienIrving authored Sep 22, 2023
2 parents 4a1af5c + 4f26103 commit 36c7f35
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 23 deletions.
42 changes: 35 additions & 7 deletions unseen/moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def calc_moments(sample_da, gev_estimates=[]):

moments = {}
moments["mean"] = float(np.mean(sample_da))
moments["std"] = float(np.std(sample_da))
moments["standard deviation"] = float(np.std(sample_da))
moments["skew"] = float(scipy.stats.skew(sample_da))
moments["kurtosis"] = float(scipy.stats.kurtosis(sample_da))
gev_shape, gev_loc, gev_scale = indices.fit_gev(
sample_da, user_estimates=gev_estimates
)
moments["GEV shape"] = gev_shape
moments["GEV loc"] = gev_loc
moments["GEV location"] = gev_loc
moments["GEV scale"] = gev_scale

return moments
Expand Down Expand Up @@ -68,6 +68,7 @@ def create_plot(
da_obs,
da_bc_fcst=None,
outfile=None,
units=None,
ensemble_dim="ensemble",
init_dim="init_date",
lead_dim="lead_time",
Expand All @@ -85,6 +86,8 @@ def create_plot(
Bias corrected forecast data for metric of interest
outfile : str, optional
Path for output image file
units : str, optional
Units for plot axis labels
ensemble_dim : str, default ensemble
Name of ensemble member dimension
init_dim : str, default init_date
Expand Down Expand Up @@ -113,7 +116,15 @@ def create_plot(
bc_bootstrap_values = {}
bc_bootstrap_lower_ci = {}
bc_bootstrap_upper_ci = {}
moments = ["mean", "std", "skew", "kurtosis", "GEV shape", "GEV loc", "GEV scale"]
moments = [
"mean",
"standard deviation",
"skew",
"kurtosis",
"GEV shape",
"GEV location",
"GEV scale",
]
for moment in moments:
bootstrap_values[moment] = []
bootstrap_lower_ci[moment] = []
Expand All @@ -127,7 +138,7 @@ def create_plot(
random_sample = np.random.choice(da_fcst_stacked, sample_size)
sample_moments = calc_moments(
random_sample,
gev_estimates=[moments_fcst["GEV loc"], moments_fcst["GEV scale"]],
gev_estimates=[moments_fcst["GEV location"], moments_fcst["GEV scale"]],
)
for moment in moments:
bootstrap_values[moment].append(sample_moments[moment])
Expand All @@ -136,7 +147,7 @@ def create_plot(
bc_random_sample = np.random.choice(da_bc_fcst_stacked, sample_size)
bc_sample_moments = calc_moments(
bc_random_sample,
gev_estimates=[moments_fcst["GEV loc"], moments_fcst["GEV scale"]],
gev_estimates=[moments_fcst["GEV location"], moments_fcst["GEV scale"]],
)
for moment in moments:
bc_bootstrap_values[moment].append(bc_sample_moments[moment])
Expand All @@ -151,7 +162,17 @@ def create_plot(
bc_bootstrap_upper_ci[moment] = bc_upper_ci

letters = "abcdefg"
fig = plt.figure(figsize=[15, 20])
units_label = units if units else da_fcst.attrs["units"]
units = {
"mean": f"mean ({units_label})",
"standard deviation": f"standard deviation ({units_label})",
"skew": "skew",
"kurtosis": "kurtosis",
"GEV shape": "shape parameter",
"GEV scale": "scale parameter",
"GEV location": "location parameter",
}
fig = plt.figure(figsize=[15, 22])
for plotnum, moment in enumerate(moments):
ax = fig.add_subplot(4, 2, plotnum + 1)
ax.hist(
Expand Down Expand Up @@ -191,6 +212,7 @@ def create_plot(
linewidth=3.0,
)
ax.set_ylabel("count")
ax.set_xlabel(units[moment])
letter = letters[plotnum]
ax.set_title(f"({letter}) {moment}")
if letter == "a":
Expand Down Expand Up @@ -255,6 +277,12 @@ def _parse_command_line():
default=None,
help="Minimum lead time",
)
parser.add_argument(
"--units",
type=str,
default=None,
help="Units label for the plot axes",
)
args = parser.parse_args()

return args
Expand Down Expand Up @@ -295,7 +323,7 @@ def _main():
da_obs,
da_bc_fcst=da_bc_fcst,
outfile=args.outfile,
min_lead=args.min_lead,
units=args.units,
ensemble_dim=args.ensemble_dim,
init_dim=args.init_dim,
lead_dim=args.lead_dim,
Expand Down
73 changes: 57 additions & 16 deletions unseen/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import genextreme as gev
import matplotlib as mpl

from . import fileio
from . import indices
from . import time_utils


def plot_dist_by_lead(ax, sample_da, metric, lead_dim="lead_time"):
mpl.rcParams["axes.titlesize"] = "xx-large"
mpl.rcParams["xtick.labelsize"] = "x-large"
mpl.rcParams["ytick.labelsize"] = "x-large"
mpl.rcParams["legend.fontsize"] = "large"
axis_label_size = "large"


def plot_dist_by_lead(ax, sample_da, metric, units=None, lead_dim="lead_time"):
"""Plot distribution curves for each lead time.
Parameters
Expand All @@ -24,6 +32,8 @@ def plot_dist_by_lead(ax, sample_da, metric, lead_dim="lead_time"):
Stacked forecast array with a sample dimension
metric : str
Metric name for plot title
units : str, optional
Units for plot axis labels
lead_dim : str, default 'lead_time'
Name of the lead time dimension in sample_da
"""
Expand All @@ -41,14 +51,16 @@ def plot_dist_by_lead(ax, sample_da, metric, lead_dim="lead_time"):
ax=ax,
color=color,
label=f"lead time: {lead} year ({n_values} samples)",
linewidth=2.0,
)
ax.grid(True)
ax.set_title(f"(a) {metric} distribution by lead time")
ax.set_xlabel(sample_da.attrs["units"])
units_label = units if units else sample_da.attrs["units"]
ax.set_xlabel(units_label, fontsize=axis_label_size)
ax.legend()


def plot_dist_by_time(ax, sample_da, metric, start_years):
def plot_dist_by_time(ax, sample_da, metric, start_years, units=None):
"""Plot distribution curves for each time slice (e.g. decade).
Parameters
Expand All @@ -61,6 +73,8 @@ def plot_dist_by_time(ax, sample_da, metric, start_years):
Metric name for plot title
start_years : list
Equally spaced list of start years
units : str, optional
Units for plot axis labels
"""

step = start_years[1] - start_years[0] - 1
Expand All @@ -79,10 +93,12 @@ def plot_dist_by_time(ax, sample_da, metric, start_years):
ax=ax,
color=color,
label=f"{start_year}-{end_year} ({n_values} samples)",
linewidth=2.0,
)
ax.grid(True)
ax.set_title(f"(c) {metric} distribution by year")
ax.set_xlabel(sample_da.attrs["units"])
units_label = units if units else sample_da.attrs["units"]
ax.set_xlabel(units_label, fontsize=axis_label_size)
ax.legend()


Expand Down Expand Up @@ -139,7 +155,14 @@ def plot_return(data, method, outfile=None):


def plot_return_by_lead(
ax, sample_da, metric, method, uncertainty=False, ymax=None, lead_dim="lead_time"
ax,
sample_da,
metric,
method,
uncertainty=False,
units=None,
ymax=None,
lead_dim="lead_time",
):
"""Plot return period curves for each lead time.
Expand All @@ -155,6 +178,8 @@ def plot_return_by_lead(
Method for producing return period curve
uncertainty : bool, default False
Plot 95% confidence interval
units : str, optional
Units for plot axis labels
ymax : float, optional
ymax for return curve plot
lead_dim : str, default 'lead_time'
Expand All @@ -170,7 +195,7 @@ def plot_return_by_lead(
n_values = len(selection_da)
label = f"lead time {lead} ({n_values} samples)"
color = next(colors)
ax.plot(return_periods, return_values, label=label, color=color)
ax.plot(return_periods, return_values, label=label, color=color, linewidth=2.0)

if uncertainty:
random_return_values = []
Expand All @@ -187,20 +212,21 @@ def plot_return_by_lead(
lower_ci,
label="95% confidence interval",
color="0.5",
alpha=0.1,
alpha=0.3,
)

ax.grid(True)
ax.set_title(f"(b) {metric} return period by lead time")
ax.set_xscale("log")
ax.set_xlabel("return period (years)")
ax.set_ylabel(sample_da.attrs["units"])
ax.set_xlabel("return period (years)", fontsize=axis_label_size)
units_label = units if units else sample_da.attrs["units"]
ax.set_ylabel(units_label, fontsize=axis_label_size)
ax.legend()
ax.set_ylim((50, ymax))


def plot_return_by_time(
ax, sample_da, metric, start_years, method, uncertainty=False, ymax=None
ax, sample_da, metric, start_years, method, uncertainty=False, units=None, ymax=None
):
"""Plot return period curves for each time slice (e.g. decade).
Expand All @@ -218,6 +244,8 @@ def plot_return_by_time(
Method for producing return period curve
uncertainty : bool, default False
Plot 95% confidence interval
units : str, optional
Units for plot axis labels
ymax : float, optional
ymax for return curve plot
"""
Expand All @@ -234,7 +262,7 @@ def plot_return_by_time(
n_years = len(selection_da)
label = f"{start_year}-{end_year} ({n_years} samples)"
color = next(colors)
ax.plot(return_periods, return_values, label=label, color=color)
ax.plot(return_periods, return_values, label=label, color=color, linewidth=2.0)

if uncertainty:
random_return_values = []
Expand All @@ -251,14 +279,15 @@ def plot_return_by_time(
lower_ci,
label="95% confidence interval",
color="0.5",
alpha=0.2,
alpha=0.3,
)

ax.grid(True)
ax.set_title(f"(d) {metric} return period by year")
ax.set_xscale("log")
ax.set_xlabel("return period (years)")
ax.set_ylabel(sample_da.attrs["units"])
ax.set_xlabel("return period (years)", fontsize=axis_label_size)
units_label = units if units else sample_da.attrs["units"]
ax.set_ylabel(units_label, fontsize=axis_label_size)
ax.set_ylim((50, ymax))
ax.legend()

Expand All @@ -270,6 +299,7 @@ def create_plot(
outfile=None,
uncertainty=False,
ymax=None,
units=None,
return_method="empirical",
ensemble_dim="ensemble",
init_dim="init_date",
Expand All @@ -291,6 +321,8 @@ def create_plot(
Plot the 95% confidence interval
ymax : float, optional
ymax for return curve plots
units : str, optional
Units for plot axis labels
return_method : {'empirical', 'gev'}, default empirical
Method for fitting the return period curve
ensemble_dim : str, default ensemble
Expand All @@ -310,23 +342,25 @@ def create_plot(
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)

plot_dist_by_lead(ax1, da_fcst_stacked, metric, lead_dim=lead_dim)
plot_dist_by_lead(ax1, da_fcst_stacked, metric, units=units, lead_dim=lead_dim)
plot_return_by_lead(
ax2,
da_fcst_stacked,
metric,
return_method,
uncertainty=uncertainty,
ymax=ymax,
units=units,
lead_dim=lead_dim,
)
plot_dist_by_time(ax3, da_fcst_stacked, metric, start_years)
plot_dist_by_time(ax3, da_fcst_stacked, metric, start_years, units=units)
plot_return_by_time(
ax4,
da_fcst_stacked,
metric,
start_years,
return_method,
units=units,
uncertainty=uncertainty,
ymax=ymax,
)
Expand Down Expand Up @@ -394,6 +428,12 @@ def _parse_command_line():
default="lead_time",
help="Name of lead time dimension",
)
parser.add_argument(
"--units",
type=str,
default=None,
help="Units label for the plot axes",
)
args = parser.parse_args()

return args
Expand All @@ -415,6 +455,7 @@ def _main():
return_method=args.return_method,
uncertainty=args.uncertainty,
ymax=args.ymax,
units=args.units,
ensemble_dim=args.ensemble_dim,
init_dim=args.init_dim,
lead_dim=args.lead_dim,
Expand Down

0 comments on commit 36c7f35

Please sign in to comment.