Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use observation weights in score #342

Merged
merged 4 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,9 @@ def score(
"""

weights = kwargs.pop("weights", None)
if weights is None:
weights = {c.name: c.weight for c in self.comparers.values()}

metric = _parse_metric(metric, self.metrics)
if not (callable(metric) or isinstance(metric, str)):
raise ValueError("metric must be a string or a function")
Expand Down
11 changes: 8 additions & 3 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def _parse_dataset(data) -> xr.Dataset:
data["Observation"].attrs["units"] = Quantity.undefined().unit

data.attrs["modelskill_version"] = __version__

if "weight" not in data.attrs:
data.attrs["weight"] = 1.0
return data


Expand Down Expand Up @@ -470,6 +473,7 @@ def from_matched_data(
mod_items: Optional[Iterable[str | int]] = None,
aux_items: Optional[Iterable[str | int]] = None,
name: Optional[str] = None,
weight: float = 1.0,
x: Optional[float] = None,
y: Optional[float] = None,
z: Optional[float] = None,
Expand All @@ -489,6 +493,7 @@ def from_matched_data(
z=z,
quantity=quantity,
)
data.attrs["weight"] = weight
return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

def __repr__(self):
Expand Down Expand Up @@ -615,11 +620,11 @@ def obs_name(self) -> str:

@property
def weight(self) -> float:
return self.data[self._obs_name].attrs["weight"]
return self.data.attrs["weight"]

@weight.setter
def weight(self, value: float) -> None:
self.data[self._obs_name].attrs["weight"] = value
self.data.attrs["weight"] = value

@property
def _unit_text(self):
Expand Down Expand Up @@ -833,7 +838,7 @@ def __add__(
matched = match_space_time(
observation=self._to_observation(), raw_mod_data=raw_mod_data # type: ignore
)
cmp = self.__class__(matched_data=matched, raw_mod_data=raw_mod_data)
cmp = Comparer(matched_data=matched, raw_mod_data=raw_mod_data)

return cmp
else:
Expand Down
3 changes: 3 additions & 0 deletions modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def from_matched(
aux_items: Optional[Iterable[str | int]] = None,
quantity: Optional[Quantity] = None,
name: Optional[str] = None,
weight: float = 1.0,
x: Optional[float] = None,
y: Optional[float] = None,
z: Optional[float] = None,
Expand Down Expand Up @@ -139,6 +140,7 @@ def from_matched(
mod_items=mod_items,
aux_items=aux_items,
name=name,
weight=weight,
x=x,
y=y,
z=z,
Expand Down Expand Up @@ -275,6 +277,7 @@ def _single_obs_compare(

raw_mod_data = {m.name: m.extract(obs) for m in mods}
matched_data = match_space_time(obs, raw_mod_data, max_model_gap)
matched_data.attrs["weight"] = obs.weight

return Comparer(matched_data=matched_data, raw_mod_data=raw_mod_data)

Expand Down
8 changes: 5 additions & 3 deletions modelskill/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class Observation(TimeSeries):
def __init__(
self,
data: xr.Dataset,
weight: float = 1.0, # TODO: cannot currently be set
weight: float,
color: str = "#d62728", # TODO: cannot currently be set
) -> None:
data["time"] = self._parse_time(data.time)
Expand Down Expand Up @@ -173,6 +173,7 @@ def __init__(
y: Optional[float] = None,
z: Optional[float] = None,
name: Optional[str] = None,
weight: float = 1.0,
quantity: Optional[Quantity] = None,
aux_items: Optional[list[int | str]] = None,
attrs: Optional[dict] = None,
Expand All @@ -194,7 +195,7 @@ def __init__(
_validate_attrs(data.attrs, attrs)
data.attrs = {**data.attrs, **(attrs or {})}

super().__init__(data=data)
super().__init__(data=data, weight=weight)

@property
def geometry(self):
Expand Down Expand Up @@ -316,6 +317,7 @@ def __init__(
*,
item: Optional[int | str] = None,
name: Optional[str] = None,
weight: float = 1.0,
x_item: Optional[int | str] = 0,
y_item: Optional[int | str] = 1,
keep_duplicates: bool | str = "first",
Expand Down Expand Up @@ -350,7 +352,7 @@ def __init__(
_validate_attrs(data.attrs, attrs)
data.attrs = {**data.attrs, **(attrs or {})}

super().__init__(data=data)
super().__init__(data=data, weight=weight)

def __repr__(self):
out = f"TrackObservation: {self.name}, n={self.n_points}"
Expand Down
3 changes: 2 additions & 1 deletion modelskill/plotting/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np
import pandas as pd

Expand All @@ -16,7 +17,7 @@ def _get_ax(ax=None, figsize=None):
return ax


def _get_fig_ax(ax: plt.Axes | None = None, figsize=None):
def _get_fig_ax(ax: Axes | None = None, figsize=None):
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
Expand Down
88 changes: 73 additions & 15 deletions tests/test_pointcompare.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,85 @@ def test_score(modelresult_oresund_WL, klagshamn, drogden):
s.root_mean_squared_error.data.mean() == pytest.approx(0.1986296276629835)


# def test_weighted_score(modelresult_oresund_WL, klagshamn, drogden):
# mr = modelresult_oresund_WL
def test_weighted_score(modelresult_oresund_WL):
o1 = ms.PointObservation(
"tests/testdata/smhi_2095_klagshamn.dfs0",
item=0,
x=366844,
y=6154291,
name="Klagshamn",
)
o2 = ms.PointObservation(
"tests/testdata/dmi_30357_Drogden_Fyr.dfs0",
item=0,
x=355568.0,
y=6156863.0,
quantity=ms.Quantity(
"Water Level", unit="meter"
), # not sure if this is relevant in this test
)

mr = ms.model_result("tests/testdata/Oresund2D_subset.dfsu", item=0, name="Oresund")

cc = ms.match(obs=[o1, o2], mod=mr)
unweighted = cc.score()
assert unweighted["Oresund"] == pytest.approx(0.1986296276629835)

# Weighted

o1_w = ms.PointObservation(
"tests/testdata/smhi_2095_klagshamn.dfs0",
item=0,
x=366844,
y=6154291,
name="Klagshamn",
weight=10.0,
)

o2_w = ms.PointObservation(
"tests/testdata/dmi_30357_Drogden_Fyr.dfs0",
item=0,
x=355568.0,
y=6156863.0,
quantity=ms.Quantity(
"Water Level", unit="meter"
), # not sure if this is relevant in this test
weight=0.1,
)

# cc = ms.match([klagshamn, drogden], mr)
# unweighted_skill = cc.score()
cc_w = ms.match(obs=[o1_w, o2_w], mod=mr)
weighted = cc_w.score()

# con = ms.Connector()
assert weighted["Oresund"] == pytest.approx(0.1666888485806514)

# con.add(klagshamn, mr, weight=0.9, validate=False)
# con.add(drogden, mr, weight=0.1, validate=False)
# cc = con.extract()
# weighted_skill = cc.score()
# assert unweighted_skill != weighted_skill

# obs = [klagshamn, drogden]
def test_weighted_score_from_prematched():
df = pd.DataFrame(
{"Oresund": [0.0, 1.0], "klagshamn": [0.0, 1.0], "drogden": [-1.0, 2.0]}
)

cmp1 = ms.from_matched(
df[["Oresund", "klagshamn"]],
mod_items=["Oresund"],
obs_item="klagshamn",
weight=100.0,
)
cmp2 = ms.from_matched(
df[["Oresund", "drogden"]],
mod_items=["Oresund"],
obs_item="drogden",
weight=0.0,
)
assert cmp1.weight == 100.0
assert cmp2.weight == 0.0
assert cmp1.score()["Oresund"] == pytest.approx(0.0)
assert cmp2.score()["Oresund"] == pytest.approx(1.0)

# con = ms.Connector(obs, mr, weight=[0.9, 0.1], validate=False)
# cc = con.extract()
# weighted_skill2 = cc.score()
cc = ms.ComparerCollection([cmp1, cmp2])
assert cc["klagshamn"].weight == 100.0
assert cc["drogden"].weight == 0.0

# assert weighted_skill == weighted_skill2
assert cc.score()["Oresund"] == pytest.approx(1.0)


def test_misc_properties(klagshamn, drogden):
Expand Down
Loading