Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use observation weights in score #342

Merged
merged 4 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,9 @@ def score(
"""

weights = kwargs.pop("weights", None)
if weights is None:
weights = {c.name: c.weight for c in self.comparers.values()}

metric = _parse_metric(metric, self.metrics)
if not (callable(metric) or isinstance(metric, str)):
raise ValueError("metric must be a string or a function")
Expand Down
11 changes: 8 additions & 3 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def _parse_dataset(data) -> xr.Dataset:
data["Observation"].attrs["units"] = Quantity.undefined().unit

data.attrs["modelskill_version"] = __version__

if "weight" not in data.attrs:
data.attrs["weight"] = 1.0
return data


Expand Down Expand Up @@ -470,6 +473,7 @@ def from_matched_data(
mod_items: Optional[Iterable[str | int]] = None,
aux_items: Optional[Iterable[str | int]] = None,
name: Optional[str] = None,
weight: float = 1.0,
x: Optional[float] = None,
y: Optional[float] = None,
z: Optional[float] = None,
Expand All @@ -489,6 +493,7 @@ def from_matched_data(
z=z,
quantity=quantity,
)
data.attrs["weight"] = weight
return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

def __repr__(self):
Expand Down Expand Up @@ -615,11 +620,11 @@ def obs_name(self) -> str:

@property
def weight(self) -> float:
return self.data[self._obs_name].attrs["weight"]
return self.data.attrs["weight"]

@weight.setter
def weight(self, value: float) -> None:
self.data[self._obs_name].attrs["weight"] = value
self.data.attrs["weight"] = value

@property
def _unit_text(self):
Expand Down Expand Up @@ -833,7 +838,7 @@ def __add__(
matched = match_space_time(
observation=self._to_observation(), raw_mod_data=raw_mod_data # type: ignore
)
cmp = self.__class__(matched_data=matched, raw_mod_data=raw_mod_data)
cmp = Comparer(matched_data=matched, raw_mod_data=raw_mod_data)

return cmp
else:
Expand Down
3 changes: 3 additions & 0 deletions modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def from_matched(
aux_items: Optional[Iterable[str | int]] = None,
quantity: Optional[Quantity] = None,
name: Optional[str] = None,
weight: float = 1.0,
x: Optional[float] = None,
y: Optional[float] = None,
z: Optional[float] = None,
Expand Down Expand Up @@ -139,6 +140,7 @@ def from_matched(
mod_items=mod_items,
aux_items=aux_items,
name=name,
weight=weight,
x=x,
y=y,
z=z,
Expand Down Expand Up @@ -275,6 +277,7 @@ def _single_obs_compare(

raw_mod_data = {m.name: m.extract(obs) for m in mods}
matched_data = match_space_time(obs, raw_mod_data, max_model_gap)
matched_data.attrs["weight"] = obs.weight

return Comparer(matched_data=matched_data, raw_mod_data=raw_mod_data)

Expand Down
16 changes: 9 additions & 7 deletions modelskill/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,23 @@ class Observation(TimeSeries):
def __init__(
self,
data: xr.Dataset,
weight: float = 1.0, # TODO: cannot currently be set
weight: float,
color: str = "#d62728", # TODO: cannot currently be set
) -> None:
data["time"] = self._parse_time(data.time)

super().__init__(data=data)
self.data[self.name].attrs["weight"] = weight
self.data[self.name].attrs["color"] = color
self.data.attrs["weight"] = weight
self.data.attrs["color"] = color

@property
def weight(self) -> float:
"""Weighting factor for skill scores"""
return self.data[self.name].attrs["weight"]
return self.data.attrs["weight"]

@weight.setter
def weight(self, value: float) -> None:
self.data[self.name].attrs["weight"] = value
self.data.attrs["weight"] = value

# TODO: move this to TimeSeries?
@staticmethod
Expand Down Expand Up @@ -173,6 +173,7 @@ def __init__(
y: Optional[float] = None,
z: Optional[float] = None,
name: Optional[str] = None,
weight: float = 1.0,
quantity: Optional[Quantity] = None,
aux_items: Optional[list[int | str]] = None,
attrs: Optional[dict] = None,
Expand All @@ -194,7 +195,7 @@ def __init__(
_validate_attrs(data.attrs, attrs)
data.attrs = {**data.attrs, **(attrs or {})}

super().__init__(data=data)
super().__init__(data=data, weight=weight)

@property
def geometry(self):
Expand Down Expand Up @@ -316,6 +317,7 @@ def __init__(
*,
item: Optional[int | str] = None,
name: Optional[str] = None,
weight: float = 1.0,
x_item: Optional[int | str] = 0,
y_item: Optional[int | str] = 1,
keep_duplicates: bool | str = "first",
Expand Down Expand Up @@ -350,7 +352,7 @@ def __init__(
_validate_attrs(data.attrs, attrs)
data.attrs = {**data.attrs, **(attrs or {})}

super().__init__(data=data)
super().__init__(data=data, weight=weight)

def __repr__(self):
out = f"TrackObservation: {self.name}, n={self.n_points}"
Expand Down
3 changes: 2 additions & 1 deletion modelskill/plotting/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np
import pandas as pd

Expand All @@ -16,7 +17,7 @@ def _get_ax(ax=None, figsize=None):
return ax


def _get_fig_ax(ax: plt.Axes | None = None, figsize=None):
def _get_fig_ax(ax: Axes | None = None, figsize=None):
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
Expand Down
2 changes: 1 addition & 1 deletion modelskill/plotting/_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def _plot_summary_table(
for ti in text_:
text_col_i = fig.text(x + dx, 0.6, ti, **txt_settings)
## Render, and get width
plt.draw()
# plt.draw() # TOOO this causes an error and I have no idea why it is here
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@daniel-caichac-DHI is this line important?

dx = (
dx
+ figure_transform.inverted().transform(
Expand Down
5 changes: 5 additions & 0 deletions tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ def test_save_and_load_preserves_raw_model_data(cc, tmp_path):
assert len(cc2["fake point obs"].raw_mod_data["m1"]) == 6


def test_scatter(cc):
ax = cc.plot.scatter(skill_table=True)
assert ax is not None


def test_hist(cc):
ax = cc.sel(model="m1").plot.hist()
assert ax is not None
Expand Down
88 changes: 73 additions & 15 deletions tests/test_pointcompare.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,85 @@ def test_score(modelresult_oresund_WL, klagshamn, drogden):
s.root_mean_squared_error.data.mean() == pytest.approx(0.1986296276629835)


# def test_weighted_score(modelresult_oresund_WL, klagshamn, drogden):
# mr = modelresult_oresund_WL
def test_weighted_score(modelresult_oresund_WL):
o1 = ms.PointObservation(
"tests/testdata/smhi_2095_klagshamn.dfs0",
item=0,
x=366844,
y=6154291,
name="Klagshamn",
)
o2 = ms.PointObservation(
"tests/testdata/dmi_30357_Drogden_Fyr.dfs0",
item=0,
x=355568.0,
y=6156863.0,
quantity=ms.Quantity(
"Water Level", unit="meter"
), # not sure if this is relevant in this test
)

mr = ms.model_result("tests/testdata/Oresund2D_subset.dfsu", item=0, name="Oresund")

cc = ms.match(obs=[o1, o2], mod=mr)
unweighted = cc.score()
assert unweighted["Oresund"] == pytest.approx(0.1986296276629835)

# Weighted

o1_w = ms.PointObservation(
"tests/testdata/smhi_2095_klagshamn.dfs0",
item=0,
x=366844,
y=6154291,
name="Klagshamn",
weight=10.0,
)

o2_w = ms.PointObservation(
"tests/testdata/dmi_30357_Drogden_Fyr.dfs0",
item=0,
x=355568.0,
y=6156863.0,
quantity=ms.Quantity(
"Water Level", unit="meter"
), # not sure if this is relevant in this test
weight=0.1,
)

# cc = ms.match([klagshamn, drogden], mr)
# unweighted_skill = cc.score()
cc_w = ms.match(obs=[o1_w, o2_w], mod=mr)
weighted = cc_w.score()

# con = ms.Connector()
assert weighted["Oresund"] == pytest.approx(0.1666888485806514)

# con.add(klagshamn, mr, weight=0.9, validate=False)
# con.add(drogden, mr, weight=0.1, validate=False)
# cc = con.extract()
# weighted_skill = cc.score()
# assert unweighted_skill != weighted_skill

# obs = [klagshamn, drogden]
def test_weighted_score_from_prematched():
df = pd.DataFrame(
{"Oresund": [0.0, 1.0], "klagshamn": [0.0, 1.0], "drogden": [-1.0, 2.0]}
)

cmp1 = ms.from_matched(
df[["Oresund", "klagshamn"]],
mod_items=["Oresund"],
obs_item="klagshamn",
weight=100.0,
)
cmp2 = ms.from_matched(
df[["Oresund", "drogden"]],
mod_items=["Oresund"],
obs_item="drogden",
weight=0.0,
)
assert cmp1.weight == 100.0
assert cmp2.weight == 0.0
assert cmp1.score()["Oresund"] == pytest.approx(0.0)
assert cmp2.score()["Oresund"] == pytest.approx(1.0)

# con = ms.Connector(obs, mr, weight=[0.9, 0.1], validate=False)
# cc = con.extract()
# weighted_skill2 = cc.score()
cc = ms.ComparerCollection([cmp1, cmp2])
assert cc["klagshamn"].weight == 100.0
assert cc["drogden"].weight == 0.0

# assert weighted_skill == weighted_skill2
assert cc.score()["Oresund"] == pytest.approx(1.0)


def test_misc_properties(klagshamn, drogden):
Expand Down
Loading