Skip to content

Commit 1f76a5d

Browse files
committed
add MetricMeter
Signed-off-by: Zhiyuan Chen <this@zyc.ai>
1 parent 413ae17 commit 1f76a5d

File tree

8 files changed

+301
-6
lines changed

8 files changed

+301
-6
lines changed

.github/workflows/push.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
17+
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
1818
steps:
1919
- uses: actions/checkout@v3
2020
- uses: actions/setup-python@v4

danling/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@
1919

2020
from danling import metrics, modules, optim, registry, runner, tensors, typing, utils
2121

22-
from .metrics import AverageMeter, AverageMeters, MultiTaskAverageMeters
22+
from .metrics import (
23+
AverageMeter,
24+
AverageMeters,
25+
MetricMeter,
26+
MetricMeters,
27+
MultiTaskAverageMeters,
28+
MultiTaskMetricMeters,
29+
)
2330
from .registry import GlobalRegistry, Registry
2431
from .runner import AccelerateRunner, BaseRunner, TorchRunner
2532
from .tensors import NestedTensor, PNTensor, tensor
@@ -54,6 +61,9 @@
5461
"GlobalRegistry",
5562
"Metrics",
5663
"MultiTaskMetrics",
64+
"MetricMeter",
65+
"MetricMeters",
66+
"MultiTaskMetricMeters",
5767
"AverageMeter",
5868
"AverageMeters",
5969
"MultiTaskAverageMeters",

danling/metrics/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from lazy_imports import try_import
2121

2222
from .average_meter import AverageMeter, AverageMeters, MultiTaskAverageMeters
23+
from .metric_meter import MetricMeter, MetricMeters, MultiTaskMetricMeters
2324

2425
with try_import() as lazy_import:
2526
from .functional import accuracy, auprc, auroc, mcc, pearson, r2_score, rmse, spearman
@@ -28,6 +29,9 @@
2829
__all__ = [
2930
"Metrics",
3031
"MultiTaskMetrics",
32+
"MetricMeter",
33+
"MetricMeters",
34+
"MultiTaskMetricMeters",
3135
"AverageMeter",
3236
"AverageMeters",
3337
"MultiTaskAverageMeters",

danling/metrics/functional.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
# pylint: disable=redefined-builtin
1919
from __future__ import annotations
2020

21+
from collections.abc import Sequence
22+
2123
import torch
2224
from lazy_imports import try_import
2325
from torch import Tensor
@@ -180,7 +182,19 @@ def rmse(
180182
return mse(input, target).sqrt()
181183

182184

183-
def preprocess(input: Tensor | NestedTensor, target: Tensor | NestedTensor, ignored_index: int | None = None):
185+
def preprocess(
186+
input: Tensor | NestedTensor | Sequence, target: Tensor | NestedTensor | Sequence, ignored_index: int | None = None
187+
):
188+
if not isinstance(input, (Tensor, NestedTensor)):
189+
try:
190+
input = torch.tensor(input)
191+
except ValueError:
192+
input = NestedTensor(input)
193+
if not isinstance(target, (Tensor, NestedTensor)):
194+
try:
195+
target = torch.tensor(target)
196+
except ValueError:
197+
target = NestedTensor(target)
184198
if isinstance(input, NestedTensor) or isinstance(target, NestedTensor):
185199
if isinstance(input, NestedTensor) and isinstance(target, Tensor):
186200
target = input.nested_like(target, strict=False)

danling/metrics/metric_meter.py

+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# DanLing
2+
# Copyright (C) 2022-Present DanLing
3+
4+
# This program is free software: you can redistribute it and/or modify
5+
# it under the terms of the following licenses:
6+
# - The Unlicense
7+
# - GNU Affero General Public License v3.0 or later
8+
# - GNU General Public License v2.0 or later
9+
# - BSD 4-Clause "Original" or "Old" License
10+
# - MIT License
11+
# - Apache License 2.0
12+
13+
# This program is distributed in the hope that it will be useful,
14+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
15+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
16+
# See the LICENSE file for more details.
17+
18+
from __future__ import annotations
19+
20+
from collections.abc import Mapping, Sequence
21+
from typing import Any, Callable, Optional, Tuple
22+
23+
from torch import Tensor
24+
25+
from ..tensors import NestedTensor
26+
from .average_meter import AverageMeter, AverageMeters, MultiTaskAverageMeters
27+
from .functional import preprocess
28+
from .utils import MultiTaskDict
29+
30+
31+
class MetricMeter(AverageMeter):
32+
r"""
33+
Computes metrics and averages them over time.
34+
35+
Attributes:
36+
metric: Metric function for computing the value.
37+
ignored_index: Index to be ignored in the computation.
38+
val: Results of current batch on current device.
39+
bat: Results of current batch on all devices.
40+
avg: Results of all results on all devices.
41+
sum: Sum of values.
42+
count: Number of values.
43+
44+
See Also:
45+
[`AverageMeter`]: Average meter for computed values.
46+
[`MetricMeters`]: Manage multiple metric meters in one object.
47+
48+
Examples:
49+
>>> from danling.metrics.functional import accuracy
50+
>>> meter = MetricMeter(accuracy)
51+
>>> meter.update([0.1, 0.8, 0.6, 0.2], [0, 1, 0, 0])
52+
>>> meter.val
53+
0.75
54+
>>> meter.avg
55+
0.75
56+
>>> meter.update([0.1, 0.7, 0.3, 0.2, 0.8, 0.4], [0, 1, 1, 0, 0, 1])
57+
>>> meter.val
58+
0.5
59+
>>> meter.avg
60+
0.6
61+
>>> meter.sum
62+
6.0
63+
>>> meter.count
64+
10
65+
>>> meter.reset()
66+
>>> meter.val
67+
0
68+
>>> meter.avg
69+
nan
70+
"""
71+
72+
metric: Callable
73+
ignored_index: Optional[int] = None
74+
75+
def __init__(self, metric: Callable, ignored_index: int | None = None) -> None:
76+
self.metric = metric
77+
self.ignored_index = ignored_index
78+
super().__init__()
79+
80+
def reset(self) -> None:
81+
r"""
82+
Resets the meter.
83+
"""
84+
85+
self.val = 0
86+
self.n = 1
87+
self.sum = 0
88+
self.count = 0
89+
90+
def update( # type: ignore[override] # pylint: disable=W0237
91+
self,
92+
input: Tensor | NestedTensor | Sequence, # pylint: disable=W0622
93+
target: Tensor | NestedTensor | Sequence,
94+
) -> None:
95+
r"""
96+
Updates the average and current value in the meter.
97+
98+
Args:
99+
value: Value to be added to the average.
100+
n: Number of values to be added.
101+
"""
102+
103+
input, target = preprocess(input, target, ignored_index=self.ignored_index)
104+
super().update(self.metric(input, target).item(), n=len(input))
105+
106+
107+
class MetricMeters(AverageMeters):
108+
r"""
109+
Manages multiple metric meters in one object.
110+
111+
Attributes:
112+
ignored_index: Index to be ignored in the computation.
113+
Defaults to None.
114+
115+
See Also:
116+
[`MetricMeter`]: Computes metrics and averages them over time.
117+
[`AverageMeters`]: Average meters for computed values.
118+
119+
>>> from danling.metrics.functional import accuracy, auroc, auprc
120+
>>> meters = MetricMeters(acc=accuracy, auroc=auroc, auprc=auprc)
121+
>>> meters.update([0.1, 0.8, 0.6, 0.2], [0, 1, 0, 0])
122+
>>> meters.sum.dict()
123+
{'acc': 3.0, 'auroc': 4.0, 'auprc': 4.0}
124+
>>> meters.count.dict()
125+
{'acc': 4, 'auroc': 4, 'auprc': 4}
126+
>>> meters['auroc'].update([0.2, 0.8], [0, 1])
127+
>>> meters.sum.dict()
128+
{'acc': 3.0, 'auroc': 6.0, 'auprc': 4.0}
129+
>>> meters.count.dict()
130+
{'acc': 4, 'auroc': 6, 'auprc': 4}
131+
>>> meters.update([[0.1, 0.7, 0.3, 0.2], [0.8, 0.4]], [[0, 0, 1, 0], [0, 0]])
132+
>>> meters.sum.dict()
133+
{'acc': 6.0, 'auroc': 8.4, 'auprc': 5.5}
134+
>>> meters.count.dict()
135+
{'acc': 10, 'auroc': 12, 'auprc': 10}
136+
>>> meters['auroc'].update([0.4, 0.8, 0.6, 0.2], [0, 1, 1, 0])
137+
>>> meters.avg.dict()
138+
{'acc': 0.6, 'auroc': 0.775, 'auprc': 0.55}
139+
>>> meters.update(dict(loss=""))
140+
Traceback (most recent call last):
141+
TypeError: MetricMeters.update() missing 1 required positional argument: 'target'
142+
"""
143+
144+
ignored_index: Optional[int] = None
145+
146+
def __init__(self, *args, ignored_index: int | None = None, **kwargs) -> None:
147+
self.setattr("ignored_index", ignored_index)
148+
for meter in args:
149+
if callable(meter):
150+
meter = MetricMeter(meter, ignored_index=self.ignored_index)
151+
if not isinstance(meter, MetricMeter):
152+
raise ValueError(f"Expected meter to be an instance of MetricMeter, but got {type(meter)}")
153+
for name, meter in kwargs.items():
154+
if callable(meter):
155+
kwargs[name] = meter = MetricMeter(meter, ignored_index=self.ignored_index)
156+
if not isinstance(meter, MetricMeter):
157+
raise ValueError(f"Expected {name} to be an instance of MetricMeter, but got {type(meter)}")
158+
super().__init__(*args, default_factory=None, **kwargs) # type: ignore[arg-type]
159+
160+
def update( # type: ignore[override] # pylint: disable=W0221
161+
self,
162+
input: Tensor | NestedTensor | Sequence, # pylint: disable=W0622
163+
target: Tensor | NestedTensor | Sequence,
164+
) -> None:
165+
r"""
166+
Updates the average and current value in all meters.
167+
168+
Args:
169+
input: Input values to compute the metrics.
170+
target: Target values to compute the metrics.
171+
"""
172+
173+
input, target = preprocess(input, target, ignored_index=self.ignored_index)
174+
for meter in self.values():
175+
meter.update(input, target)
176+
177+
def set(self, name: str, meter: MetricMeter | Callable) -> None: # type: ignore[override] # pylint: disable=W0237
178+
if callable(meter):
179+
meter = MetricMeter(meter, ignored_index=self.ignored_index)
180+
if not isinstance(meter, MetricMeter):
181+
raise ValueError(f"Expected meter to be an instance of MetricMeter, but got {type(meter)}")
182+
super().set(name, meter)
183+
184+
def __repr__(self):
185+
keys = tuple(i for i in self.keys())
186+
return f"{self.__class__.__name__}{keys}"
187+
188+
189+
class MultiTaskMetricMeters(MultiTaskAverageMeters):
190+
r"""
191+
Examples:
192+
>>> from danling.metrics.functional import accuracy
193+
>>> metrics = MultiTaskMetricMeters()
194+
>>> metrics.dataset1.cls = MetricMeters(acc=accuracy)
195+
>>> metrics.dataset2 = MetricMeters(acc=accuracy)
196+
>>> metrics
197+
MultiTaskMetricMeters(<class 'danling.metrics.metric_meter.MultiTaskMetricMeters'>,
198+
('dataset1'): MultiTaskMetricMeters(<class 'danling.metrics.metric_meter.MultiTaskMetricMeters'>,
199+
('cls'): MetricMeters('acc',)
200+
)
201+
('dataset2'): MetricMeters('acc',)
202+
)
203+
>>> metrics.update({"dataset1.cls": {"input": [0.2, 0.4, 0.5, 0.7], "target": [0, 1, 0, 1]}, "dataset2": {"input": [0.1, 0.4, 0.6, 0.8], "target": [1, 0, 0, 0]}})
204+
>>> f"{metrics:.4f}"
205+
'dataset1.cls: acc: 0.5000 (0.5000)\ndataset2: acc: 0.2500 (0.2500)'
206+
>>> metrics.setattr("return_average", True)
207+
>>> metrics.update({"dataset1.cls": {"input": [0.1, 0.4, 0.6, 0.8], "target": [0, 0, 1, 0]}, "dataset2": {"input": [0.2, 0.3, 0.5, 0.7], "target": [0, 0, 0, 1]}})
208+
>>> f"{metrics:.4f}"
209+
'dataset1.cls: acc: 0.7500 (0.6250)\ndataset2: acc: 0.7500 (0.5000)'
210+
""" # noqa: E501
211+
212+
def __init__(self, *args, **kwargs):
213+
super().__init__(*args, default_factory=MultiTaskMetricMeters, **kwargs)
214+
215+
def update( # type: ignore[override] # pylint: disable=W0221
216+
self,
217+
values: Mapping[str, Mapping[str, Tuple[Tensor | NestedTensor | Sequence, Tensor | NestedTensor | Sequence]]],
218+
) -> None:
219+
r"""
220+
Updates the average and current value in all meters.
221+
222+
Args:
223+
input: Input values to compute the metrics.
224+
target: Target values to compute the metrics.
225+
"""
226+
227+
for metric, value in values.items():
228+
if isinstance(value, Mapping):
229+
if metric not in self:
230+
raise ValueError(f"Metric {metric} not found in {self}")
231+
if isinstance(self[metric], MultiTaskMetricMeters):
232+
for met in self[metric].all_values():
233+
met.update(*value)
234+
elif isinstance(self[metric], (MetricMeters, MetricMeter)):
235+
self[metric].update(*value)
236+
else:
237+
raise ValueError(
238+
f"Expected {metric} to be an instance of MultiTaskMetricMeters, MetricMeters, "
239+
"or MetricMeter, but got {type(self[metric])}"
240+
)
241+
else:
242+
raise ValueError(f"Expected values to be a flat dictionary, but got {type(value)}")
243+
244+
# MultiTaskAverageMeters.get is hacked
245+
def get(self, name: Any, default=None) -> Any:
246+
return MultiTaskDict.get(self, name, default)
247+
248+
def set( # pylint: disable=W0237
249+
self,
250+
name: str,
251+
meter: MetricMeter | MetricMeters | Callable, # type: ignore[override]
252+
) -> None:
253+
if callable(meter):
254+
meter = MetricMeter(meter)
255+
if not isinstance(meter, (MetricMeter, MetricMeters)):
256+
raise ValueError(f"Expected meter to be an instance of MetricMeter or MetricMeters, but got {type(meter)}")
257+
super().set(name, meter)

danling/runner/base_runner.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from chanfig import Config, FlatDict, NestedDict, Variable
3434

35-
from danling.metrics import AverageMeter, AverageMeters
35+
from danling.metrics import AverageMeter, AverageMeters, MetricMeters
3636
from danling.typing import File, PathStr
3737
from danling.utils import catch, ensure_dir, load, save
3838

@@ -157,7 +157,7 @@ class BaseRunner(metaclass=RunnerMeta): # pylint: disable=too-many-public-metho
157157
Attributes: logging:
158158
meters (AverageMeters | MultiTaskAverageMeters): Average meters.
159159
Initialised to `AverageMeters` by default.
160-
metrics (Metrics | MultiTaskMetrics | None): Metrics for evaluating.
160+
metrics (Metrics | MultiTaskMetrics | MetricMeters | None): Metrics for evaluating.
161161
logger:
162162
writer:
163163
@@ -186,7 +186,7 @@ class BaseRunner(metaclass=RunnerMeta): # pylint: disable=too-many-public-metho
186186

187187
results: NestedDict
188188
meters: AverageMeters
189-
metrics: Metrics | None = None
189+
metrics: Metrics | MetricMeters | None = None
190190
logger: logging.Logger | None = None
191191
writer: Any | None = None
192192

docs/docs/metrics/metric_meter.md

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
authors:
3+
- Zhiyuan Chen
4+
date: 2022-05-04
5+
---
6+
7+
# MetricMeter
8+
9+
::: danling.metrics.metric_meter

docs/mkdocs.yml

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ nav:
2424
- Metrics:
2525
- Metrics: metrics/metrics.md
2626
- AverageMeter: metrics/average_meter.md
27+
- MetricMeter: metrics/metric_meter.md
2728
- Utils: metrics/utils.md
2829
- Utilities:
2930
- Decorator: utils/decorators.md

0 commit comments

Comments
 (0)