Skip to content

Commit

Permalink
Merge pull request #44 from DamienIrving/master
Browse files Browse the repository at this point in the history
Add dual plot option to moments.py
  • Loading branch information
DamienIrving authored Aug 11, 2023
2 parents f210bde + 9e078fd commit ddc2367
Showing 1 changed file with 159 additions and 77 deletions.
236 changes: 159 additions & 77 deletions unseen/moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import scipy

from . import fileio
from . import indices


logging.basicConfig(level=logging.INFO)
Expand All @@ -22,11 +23,52 @@ def calc_ci(data):
return lower_ci, upper_ci


def calc_moments(sample_da, gev_estimates=[]):
"""Calculate all the moments for a given sample."""

moments = {}
moments["mean"] = float(np.mean(sample_da))
moments["std"] = float(np.std(sample_da))
moments["skew"] = float(scipy.stats.skew(sample_da))
moments["kurtosis"] = float(scipy.stats.kurtosis(sample_da))
gev_shape, gev_loc, gev_scale = indices.fit_gev(
sample_da, user_estimates=gev_estimates
)
moments["GEV shape"] = gev_shape
moments["GEV loc"] = gev_loc
moments["GEV scale"] = gev_scale

return moments


def log_results(moments_obs, model_lower_cis, model_upper_cis, bias_corrected=False):
"""Log the results"""

if bias_corrected:
text_insert = "Bias corrected model"
else:
text_insert = "Model"

metadata = {}
for moment, obs_value in moments_obs.items():
lower_ci = model_lower_cis[moment]
upper_ci = model_upper_cis[moment]
text = f"Obs = {obs_value}, {text_insert} 95% CI ={lower_ci} to {upper_ci}"
if bias_corrected:
metadata["bias corrected " + moment] = text
else:
metadata[moment] = text
logging.info(f"{moment}: {text}")

return metadata


def create_plot(
fcst_file,
obs_file,
var,
outfile=None,
bc_fcst_file=None,
min_lead=None,
ensemble_dim="ensemble",
init_dim="init_date",
Expand All @@ -42,8 +84,10 @@ def create_plot(
Observations file containing metric of interest
var : str
Variable name (in fcst_file)
outfile : str, default None
outfile : str, optional
Path for output image file
bc_fcst_file : str, optional
Forecast file containing bias corrected metric of interest
min_lead : int, optional
Minimum lead time
ensemble_dim : str, default ensemble
Expand All @@ -57,94 +101,130 @@ def create_plot(
ds_obs = fileio.open_dataset(obs_file)
da_obs = ds_obs[var].dropna("time")
sample_size = len(da_obs)
mean_obs = float(np.mean(da_obs))
std_obs = float(np.std(da_obs))
skew_obs = float(scipy.stats.skew(da_obs))
kurtosis_obs = float(scipy.stats.kurtosis(da_obs))
moments_obs = calc_moments(da_obs)

ds_fcst = fileio.open_dataset(fcst_file)
da_fcst = ds_fcst[var]
if min_lead is not None:
da_fcst = da_fcst.where(ds_fcst[lead_dim] >= min_lead)
dims = [ensemble_dim, init_dim, lead_dim]
da_fcst_stacked = da_fcst.dropna(lead_dim).stack({"sample": dims})
moments_fcst = calc_moments(da_fcst_stacked)

if bc_fcst_file:
ds_bc_fcst = fileio.open_dataset(bc_fcst_file)
da_bc_fcst = ds_bc_fcst[var]
if min_lead is not None:
da_bc_fcst = da_bc_fcst.where(ds_bc_fcst[lead_dim] >= min_lead)
da_bc_fcst_stacked = da_bc_fcst.dropna(lead_dim).stack({"sample": dims})

bootstrap_values = {}
bootstrap_lower_ci = {}
bootstrap_upper_ci = {}
if bc_fcst_file:
bc_bootstrap_values = {}
bc_bootstrap_lower_ci = {}
bc_bootstrap_upper_ci = {}
moments = ["mean", "std", "skew", "kurtosis", "GEV shape", "GEV loc", "GEV scale"]
for moment in moments:
bootstrap_values[moment] = []
bootstrap_lower_ci[moment] = []
bootstrap_upper_ci[moment] = []
if bc_fcst_file:
bc_bootstrap_values[moment] = []
bc_bootstrap_lower_ci[moment] = []
bc_bootstrap_upper_ci[moment] = []

mean_values = []
std_values = []
skew_values = []
kurtosis_values = []
for i in range(1000):
random_sample = np.random.choice(da_fcst_stacked, sample_size)
mean = float(np.mean(random_sample))
std = float(np.std(random_sample))
skew = float(scipy.stats.skew(random_sample))
kurtosis = float(scipy.stats.kurtosis(random_sample))
mean_values.append(mean)
std_values.append(std)
skew_values.append(skew)
kurtosis_values.append(kurtosis)

mean_lower_ci, mean_upper_ci = calc_ci(mean_values)
std_lower_ci, std_upper_ci = calc_ci(std_values)
skew_lower_ci, skew_upper_ci = calc_ci(skew_values)
kurtosis_lower_ci, kurtosis_upper_ci = calc_ci(kurtosis_values)

fig = plt.figure(figsize=[15, 12])
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)

ax1.hist(mean_values, rwidth=0.8, color="0.5")
ax1.set_title("(a) mean")
ax1.axvline(mean_lower_ci, color="0.2", linestyle="--")
ax1.axvline(mean_upper_ci, color="0.2", linestyle="--")
ax1.axvline(mean_obs)
ax1.set_ylabel("count")

ax2.hist(std_values, rwidth=0.8, color="0.5")
ax2.set_title("(b) standard deviation")
ax2.axvline(std_lower_ci, color="0.2", linestyle="--")
ax2.axvline(std_upper_ci, color="0.2", linestyle="--")
ax2.axvline(std_obs)
ax2.set_ylabel("count")

ax3.hist(skew_values, rwidth=0.8, color="0.5")
ax3.set_title("(c) skewness")
ax3.set_ylabel("count")
ax3.axvline(skew_lower_ci, color="0.2", linestyle="--")
ax3.axvline(skew_upper_ci, color="0.2", linestyle="--")
ax3.axvline(skew_obs)

ax4.hist(kurtosis_values, rwidth=0.8, color="0.5")
ax4.set_title("(d) kurtosis")
ax4.set_ylabel("count")
ax4.axvline(kurtosis_lower_ci, color="0.2", linestyle="--")
ax4.axvline(kurtosis_upper_ci, color="0.2", linestyle="--")
ax4.axvline(kurtosis_obs)

mean_text = f"Obs = {mean_obs}, Model 95% CI ={mean_lower_ci} to {mean_upper_ci}"
std_text = f"Obs = {std_obs}, Model 95% CI ={std_lower_ci} to {std_upper_ci}"
skew_text = f"Obs = {skew_obs}, Model 95% CI ={skew_lower_ci} to {skew_upper_ci}"
kurtosis_text = f"Obs = {kurtosis_obs}, Model 95% CI ={kurtosis_lower_ci} to {kurtosis_upper_ci}"
logging.info(f"Mean: {mean_text}")
logging.info(f"Standard deviation: {std_text}")
logging.info(f"Skewness: {skew_text}")
logging.info(f"Kurtosis: {kurtosis_text}")
sample_moments = calc_moments(
random_sample,
gev_estimates=[moments_fcst["GEV loc"], moments_fcst["GEV scale"]],
)
for moment in moments:
bootstrap_values[moment].append(sample_moments[moment])

if bc_fcst_file:
bc_random_sample = np.random.choice(da_bc_fcst_stacked, sample_size)
bc_sample_moments = calc_moments(
bc_random_sample,
gev_estimates=[moments_fcst["GEV loc"], moments_fcst["GEV scale"]],
)
for moment in moments:
bc_bootstrap_values[moment].append(bc_sample_moments[moment])

for moment in moments:
lower_ci, upper_ci = calc_ci(bootstrap_values[moment])
bootstrap_lower_ci[moment] = lower_ci
bootstrap_upper_ci[moment] = upper_ci
if bc_fcst_file:
bc_lower_ci, bc_upper_ci = calc_ci(bc_bootstrap_values[moment])
bc_bootstrap_lower_ci[moment] = bc_lower_ci
bc_bootstrap_upper_ci[moment] = bc_upper_ci

letters = "abcdefg"
fig = plt.figure(figsize=[15, 20])
for plotnum, moment in enumerate(moments):
ax = fig.add_subplot(4, 2, plotnum + 1)
ax.hist(
bootstrap_values[moment],
rwidth=0.8,
color="tab:blue",
alpha=0.7,
label="model",
)
ax.axvline(
bootstrap_lower_ci[moment], color="tab:blue", linestyle="--", linewidth=3.0
)
ax.axvline(
bootstrap_upper_ci[moment], color="tab:blue", linestyle="--", linewidth=3.0
)
ax.axvline(
moments_obs[moment], linewidth=4.0, color="tab:gray", label="observations"
)
if bc_fcst_file:
ax.hist(
bc_bootstrap_values[moment],
rwidth=0.8,
color="tab:orange",
alpha=0.7,
label="model (corrected)",
)
ax.axvline(
bc_bootstrap_lower_ci[moment],
color="tab:orange",
linestyle="--",
linewidth=3.0,
)
ax.axvline(
bc_bootstrap_upper_ci[moment],
color="tab:orange",
linestyle="--",
linewidth=3.0,
)
ax.set_ylabel("count")
letter = letters[plotnum]
ax.set_title(f"({letter}) {moment}")
if letter == "a":
ax.legend()

metadata = log_results(moments_obs, bootstrap_lower_ci, bootstrap_upper_ci)
if bc_fcst_file:
bc_metadata = log_results(
moments_obs,
bc_bootstrap_lower_ci,
bc_bootstrap_upper_ci,
bias_corrected=True,
)
metadata = metadata | bc_metadata

if outfile:
infile_logs = {
fcst_file: ds_fcst.attrs["history"],
obs_file: ds_obs.attrs["history"],
}
command_history = fileio.get_new_log(infile_logs=infile_logs)
metadata = {
"mean": mean_text,
"standard deviation": std_text,
"skewness": skew_text,
"kurtosis": kurtosis_text,
"history": command_history,
}
infile_logs = {obs_file: ds_obs.attrs["history"]}
if bc_fcst_file:
infile_logs[bc_fcst_file] = ds_bc_fcst.attrs["history"]
else:
infile_logs[fcst_file] = ds_fcst.attrs["history"]
metadata["history"] = fileio.get_new_log(infile_logs=infile_logs)
plt.savefig(
outfile,
bbox_inches="tight",
Expand All @@ -167,6 +247,7 @@ def _parse_command_line():
parser.add_argument("var", type=str, help="Variable name")

parser.add_argument("--outfile", type=str, default=None, help="Output file name")
parser.add_argument("--bias_file", type=str, help="Bias corrected forecast file")
parser.add_argument(
"--ensemble_dim",
type=str,
Expand Down Expand Up @@ -206,6 +287,7 @@ def _main():
args.obs_file,
args.var,
outfile=args.outfile,
bc_fcst_file=args.bias_file,
min_lead=args.min_lead,
ensemble_dim=args.ensemble_dim,
init_dim=args.init_dim,
Expand Down

0 comments on commit ddc2367

Please sign in to comment.