Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove multi verif from classes #453

Merged
merged 19 commits into from
Sep 25, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ Breaking changes
automatically to ``member`` for probabilistic metrics. (:pr:`407`) `Aaron Spring`_.
- metric :py:func:`~climpred.metrics._threshold_brier_score` now requires ``logical``
instead of ``func`` as ``metric_kwargs``. (:pr:`388`) `Aaron Spring`_.
- Remove ability to add multiple observations to
:py:class:`~climpred.classes.HindcastEnsemble`. This makes current and future
  development much easier. (:pr:`453`) `Riley X. Brady`_.

New Features
------------
Expand Down
1 change: 1 addition & 0 deletions ci/requirements/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
- jupyterlab
- matplotlib
- nbsphinx
- nc-time-axis
- netcdf4
- numpy
- pandas
Expand Down
123 changes: 31 additions & 92 deletions climpred/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,20 @@ def _display_metadata(self):
summary = header + "\nInitialized Ensemble:\n"
summary += SPACE + str(self._datasets["initialized"].data_vars)[18:].strip() + "\n"
if isinstance(self, HindcastEnsemble):
# Prints out verification data names and associated variables if they exist.
        # Prints out verification data and associated variables if they exist.
# If not, just write "None".
summary += "Verification Data:\n"
if any(self._datasets["observations"]):
for key in self._datasets["observations"]:
summary += f"{key}:\n"
num_obs = len(self._datasets["observations"][key].data_vars)
for i in range(1, num_obs + 1):
summary += (
SPACE
+ str(self._datasets["observations"][key].data_vars)
.split("\n")[i]
.strip()
+ "\n"
)
num_obs = len(self._datasets["observations"].data_vars)
for i in range(1, num_obs + 1):
summary += (
SPACE
+ str(self._datasets["observations"].data_vars)
.split("\n")[i]
.strip()
+ "\n"
)
else:
summary += "Verification Data:\n"
summary += SPACE + "None\n"
elif isinstance(self, PerfectModelEnsemble):
summary += "Control:\n"
Expand Down Expand Up @@ -101,12 +99,9 @@ def _display_metadata_html(self):

if isinstance(self, HindcastEnsemble):
if any(self._datasets["observations"]):
for key in self._datasets["observations"]:
obs_repr_str = dataset_repr(self._datasets["observations"][key])
obs_repr_str = obs_repr_str.replace(
"xarray.Dataset", f"Verification Data {key}"
)
display_html(obs_repr_str, raw=True)
obs_repr_str = dataset_repr(self._datasets["observations"])
obs_repr_str = obs_repr_str.replace("xarray.Dataset", "Verification Data")
display_html(obs_repr_str, raw=True)
elif isinstance(self, PerfectModelEnsemble):
if any(self._datasets["control"]):
control_repr_str = dataset_repr(self._datasets["control"])
Expand Down Expand Up @@ -861,9 +856,9 @@ class HindcastEnsemble(PredictionEnsemble):
"""An object for climate prediction ensembles initialized by a data-like
product.

`HindcastEnsemble` is a sub-class of `PredictionEnsemble`. It tracks all
verification data associated with the prediction ensemble for easy
computation across multiple variables and products.
`HindcastEnsemble` is a sub-class of `PredictionEnsemble`. It tracks a single
verification dataset associated with the hindcast ensemble for easy computation
across multiple variables.

This object is built on `xarray` and thus requires the input object to
be an `xarray` Dataset or DataArray.
Expand All @@ -887,46 +882,20 @@ def __init__(self, xobj):
self._datasets.update({"observations": {}})
self.kind = "hindcast"

def _apply_climpred_function(self, func, input_dict=None, **kwargs):
def _apply_climpred_function(self, func, init, **kwargs):
"""Helper function to loop through verification data and apply an arbitrary
climpred function.

Args:
func (function): climpred function to apply to object.
input_dict (dict): dictionary with the following things:
* ensemble: initialized or uninitialized ensemble.
* observations: Dictionary of verification data from
``HindcastEnsemble``.
* name: name of verification data to target.
* init: bool of whether or not it's the initialized ensemble.
init (bool): Whether or not it's the initialized ensemble.
"""
hind = self._datasets["initialized"]
verif = self._datasets["observations"]
name = input_dict["name"]
init = input_dict["init"]

# Apply only to specific observations.
if name is not None:
drop_init, drop_obs = self._vars_to_drop(name, init=init)
hind = hind.drop_vars(drop_init)
verif = verif[name].drop_vars(drop_obs)
return func(hind, verif, **kwargs)
else:
# If only one observational product, just apply to that one.
if len(verif) == 1:
name = list(verif.keys())[0]
drop_init, drop_obs = self._vars_to_drop(name, init=init)
return func(hind, verif[name], **kwargs)
# Loop through verif, apply function, and store in dictionary.
# TODO: Parallelize this process.
else:
result = {}
for name in verif.keys():
drop_init, drop_obs = self._vars_to_drop(name, init=init)
result[name] = func(hind, verif[name], **kwargs)
return result
drop_init, drop_obs = self._vars_to_drop(init=init)
return func(hind.drop_vars(drop_init), verif.drop_vars(drop_obs), **kwargs)

def _vars_to_drop(self, name, init=True):
def _vars_to_drop(self, init=True):
"""Returns list of variables to drop when comparing
initialized/uninitialized to observations.

Expand All @@ -936,7 +905,6 @@ def _vars_to_drop(self, name, init=True):
from the initialized.

Args:
name (str): Short name of observations being compared to.
init (bool, default True):
If ``True``, check variables on the initialized.
If ``False``, check variables on the uninitialized.
Expand All @@ -949,21 +917,20 @@ def _vars_to_drop(self, name, init=True):
init_vars = [var for var in self._datasets["initialized"].data_vars]
else:
init_vars = [var for var in self._datasets["uninitialized"].data_vars]
obs_vars = [var for var in self._datasets["observations"][name].data_vars]
obs_vars = [var for var in self._datasets["observations"].data_vars]
# Make lists of variables to drop that aren't in common
# with one another.
init_vars_to_drop = list(set(init_vars) - set(obs_vars))
obs_vars_to_drop = list(set(obs_vars) - set(init_vars))
return init_vars_to_drop, obs_vars_to_drop

@is_xarray(1)
def add_observations(self, xobj, name):
"""Add a verification data with which to verify the initialized ensemble.
def add_observations(self, xobj):
"""Add verification data against which to verify the initialized ensemble.

Args:
xobj (xarray object): Dataset/DataArray to append to the
``HindcastEnsemble`` object.
name (str): Short name for referencing the verification data.
"""
if isinstance(xobj, xr.DataArray):
xobj = xobj.to_dataset()
Expand All @@ -975,12 +942,8 @@ def add_observations(self, xobj, name):
# Check that converted/original cftime calendar is the same as the
# initialized calendar to avoid any alignment errors.
match_calendars(self._datasets["initialized"], xobj)
# For some reason, I could only get the non-inplace method to work
# by updating the nested dictionaries separately.
datasets_obs = self._datasets["observations"].copy()
datasets = self._datasets.copy()
datasets_obs.update({name: xobj})
datasets.update({"observations": datasets_obs})
datasets.update({"observations": xobj})
return self._construct_direct(datasets, kind="hindcast")

@is_xarray(1)
Expand All @@ -1005,29 +968,16 @@ def add_uninitialized(self, xobj):
datasets.update({"uninitialized": xobj})
return self._construct_direct(datasets, kind="hindcast")

def get_observations(self, name=None):
def get_observations(self):
"""Returns xarray Datasets of the observations/verification data.

Args:
name (str, optional): Name of the observations/verification data to return.
If ``None``, return dictionary of all observations/verification data.

Returns:
Dictionary of ``xarray`` Datasets (if ``name`` is ``None``) or single
``xarray`` Dataset.
``xarray`` Dataset of observations.
"""
if name is None:
if len(self._datasets["observations"]) == 1:
key = list(self._datasets["observations"].keys())[0]
return self._datasets["observations"][key]
else:
return self._datasets["observations"]
else:
return self._datasets["observations"][name]
return self._datasets["observations"]

def verify(
self,
name=None,
reference=None,
metric=None,
comparison=None,
Expand All @@ -1042,8 +992,6 @@ def verify(
between the initialized ensemble and observations/verification data.

Args:
name (str): Short name of observations/verification data to compare to.
If ``None``, compare to all observations/verification data.
reference (str): Type of reference forecasts to also verify against the
observations. Choose one or more of ['historical', 'persistence'].
Defaults to None.
Expand Down Expand Up @@ -1073,10 +1021,7 @@ def verify(
**metric_kwargs (optional): arguments passed to ``metric``.

Returns:
Dataset of comparison results (if comparing to one observational product),
or dictionary of Datasets with keys corresponding to
observations/verification data short name.

Dataset of comparison results.
"""
# Have to do checks here since this doesn't call `compute_hindcast` directly.
# Will be refactored when `climpred` migrates to inheritance-based.
Expand Down Expand Up @@ -1183,14 +1128,9 @@ def _verify(
else:
hist = None

# TODO: Get rid of this somehow. Might use attributes.
input_dict = {
"name": name,
"init": True,
}
res = self._apply_climpred_function(
_verify,
input_dict=input_dict,
init=True,
metric=metric,
comparison=comparison,
alignment=alignment,
Expand Down Expand Up @@ -1278,7 +1218,6 @@ def bootstrap(
difference of skill between the initialized and persistence
simulations is smaller or equal to zero based on
bootstrapping with replacement.

"""
if iterations is None:
raise ValueError("Designate number of bootstrapping `iterations`.")
Expand Down
24 changes: 10 additions & 14 deletions climpred/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,17 @@ def plot_lead_timeseries_hindcast(
zorder=hind.lead.size - i,
)

linestyles = ["-", ":", "-.", "--"]
if len(obs) > len(linestyles):
raise ValueError(f"Please provide fewer than {len(linestyles)+1} observations.")
if len(obs) > 0:
for i, (obs_name, obs_item) in enumerate(obs.items()):
if isinstance(obs_item, xr.Dataset):
obs_item = obs_item[variable]
obs_item.plot(
ax=ax,
color="k",
lw=3,
ls=linestyles[i],
label=f"reference: {obs_name}",
zorder=hind.lead.size + 2,
)
if isinstance(obs, xr.Dataset):
obs = obs[variable]
obs.plot(
ax=ax,
color="k",
lw=3,
ls="-",
label="observations",
zorder=hind.lead.size + 2,
)

# show only one item per label in legend
handles, labels = ax.get_legend_handles_labels()
Expand Down
23 changes: 17 additions & 6 deletions climpred/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,22 @@ def non_empty_datasets(he):


def check_dataset_dims_and_data_vars(before, after, dataset):
    """Checks that dimensions, coordinates, and variables are identical in
    PredictionEnsemble `before` and PredictionEnsemble `after`.

    Args:
        before, after (climpred.PredictionEnsemble): PredictionEnsembles that have had
            some operation performed on them.
        dataset (str): Name of dataset within the PredictionEnsemble. This should be
            a key to the ``_datasets`` dictionary (e.g. ``"initialized"``,
            ``"uninitialized"``, ``"control"``, or ``"observations"``).

    Asserts:
        That dimensions, coordinates, and variables are identical before and after the
        transformation.
    """
    # Since observations are now a single Dataset (not a dict of named products),
    # every entry in `_datasets` is looked up the same way.
    before = before._datasets[dataset]
    after = after._datasets[dataset]

    assert before.dims == after.dims
    assert list(before.data_vars) == list(after.data_vars)
    assert list(before.coords) == list(after.coords)
18 changes: 7 additions & 11 deletions climpred/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def hindcast_recon_3d(hind_ds_initialized_3d, reconstruction_ds_3d):
for c in ["TLAT", "TLONG", "TAREA"]:
reconstruction_ds_3d[c] = hind_ds_initialized_3d[c]
hindcast = HindcastEnsemble(hind_ds_initialized_3d)
hindcast = hindcast.add_observations(reconstruction_ds_3d, "recon")
hindcast = hindcast.add_observations(reconstruction_ds_3d)
hindcast = hindcast - hindcast.sel(time=slice("1964", "2014")).mean("time").sel(
init=slice("1964", "2014")
).mean("init")
Expand All @@ -282,7 +282,7 @@ def hindcast_recon_3d(hind_ds_initialized_3d, reconstruction_ds_3d):
def hindcast_recon_1d_ym(hind_ds_initialized_1d, reconstruction_ds_1d):
"""HindcastEnsemble initialized with `initialized`, `uninitialzed` and `recon`."""
hindcast = HindcastEnsemble(hind_ds_initialized_1d)
hindcast = hindcast.add_observations(reconstruction_ds_1d, "recon")
hindcast = hindcast.add_observations(reconstruction_ds_1d)
hindcast = hindcast - hindcast.sel(time=slice("1964", "2014")).mean("time").sel(
init=slice("1964", "2014")
).mean("init")
Expand All @@ -296,7 +296,7 @@ def hindcast_hist_obs_1d(
"""HindcastEnsemble initialized with `initialized`, `uninitialzed` and `obs`."""
hindcast = HindcastEnsemble(hind_ds_initialized_1d)
hindcast = hindcast.add_uninitialized(hist_ds_uninitialized_1d)
hindcast = hindcast.add_observations(observations_ds_1d, "obs")
hindcast = hindcast.add_observations(observations_ds_1d)
hindcast = hindcast - hindcast.sel(time=slice("1964", "2014")).mean("time").sel(
init=slice("1964", "2014")
).mean("init")
Expand All @@ -309,10 +309,8 @@ def hindcast_recon_1d_mm(hindcast_recon_1d_ym):
time series (no grid)."""
hindcast = hindcast_recon_1d_ym.sel(time=slice("1964", "1970"))
hindcast._datasets["initialized"].lead.attrs["units"] = "months"
hindcast._datasets["observations"]["recon"] = (
hindcast._datasets["observations"]["recon"]
.resample(time="1MS")
.interpolate("linear")
hindcast._datasets["observations"] = (
hindcast._datasets["observations"].resample(time="1MS").interpolate("linear")
)
return hindcast

Expand All @@ -323,10 +321,8 @@ def hindcast_recon_1d_dm(hindcast_recon_1d_ym):
time series (no grid)."""
hindcast = hindcast_recon_1d_ym.sel(time=slice("1964", "1970"))
hindcast._datasets["initialized"].lead.attrs["units"] = "days"
hindcast._datasets["observations"]["recon"] = (
hindcast._datasets["observations"]["recon"]
.resample(time="1D")
.interpolate("linear")
hindcast._datasets["observations"] = (
hindcast._datasets["observations"].resample(time="1D").interpolate("linear")
)
return hindcast

Expand Down
Loading