Skip to content

Commit 48ef85b

Browse files
committed
rename AverageMeters to MultiTaskAverageMeter
Signed-off-by: Zhiyuan Chen <this@zyc.ai>
1 parent c0fabfb commit 48ef85b

12 files changed

+338
-312
lines changed

danling/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from danling import metrics, modules, optim, registry, runner, tensors, typing, utils
22

3-
from .metrics import AverageMeter, AverageMeters
3+
from .metrics import AverageMeter, MultiTaskAverageMeter
44
from .registry import GlobalRegistry, Registry
55
from .runner import AccelerateRunner, BaseRunner, TorchRunner
66
from .tensors import NestedTensor, PNTensor
@@ -22,7 +22,7 @@
2222
"GlobalRegistry",
2323
"Metrics",
2424
"AverageMeter",
25-
"AverageMeters",
25+
"MultiTaskAverageMeter",
2626
"NestedTensor",
2727
"PNTensor",
2828
"save",

danling/metrics/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from lazy_imports import try_import
44

5-
from .average_meters import AverageMeter, AverageMeters
5+
from .average_meter import AverageMeter, MultiTaskAverageMeter
66

77
with try_import():
88
from .functional import accuracy, auprc, auroc, matthews_corrcoef, pearson, r2_score, rmse, spearman
@@ -11,7 +11,7 @@
1111
__all__ = [
1212
"Metrics",
1313
"AverageMeter",
14-
"AverageMeters",
14+
"MultiTaskAverageMeter",
1515
"regression_metrics",
1616
"binary_metrics",
1717
"multiclass_metrics",

danling/metrics/average_meter.py

+262
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict
4+
5+
from chanfig import NestedDict
6+
from torch import distributed as dist
7+
8+
from .multitask import MultiTaskDict
9+
from .utils import get_world_size
10+
11+
12+
class AverageMeter:
13+
r"""
14+
Computes and stores the average and current value.
15+
16+
Attributes:
17+
val: Results of current batch on current device.
18+
bat: Results of current batch on all devices.
19+
avg: Results of all results on all devices.
20+
sum: Sum of values.
21+
count: Number of values.
22+
23+
Examples:
24+
>>> meter = AverageMeter()
25+
>>> meter.update(0.7)
26+
>>> meter.val
27+
0.7
28+
>>> meter.avg
29+
0.7
30+
>>> meter.update(0.9)
31+
>>> meter.val
32+
0.9
33+
>>> meter.avg
34+
0.8
35+
>>> meter.sum
36+
1.6
37+
>>> meter.count
38+
2
39+
>>> meter.reset()
40+
>>> meter.val
41+
0
42+
>>> meter.avg
43+
nan
44+
"""
45+
46+
val: float = 0
47+
n: float = 1
48+
sum: float = 0
49+
count: float = 0
50+
51+
def __init__(self) -> None:
52+
self.reset()
53+
54+
def reset(self) -> None:
55+
r"""
56+
Resets the meter.
57+
58+
Examples:
59+
>>> meter = AverageMeter()
60+
>>> meter.update(0.7)
61+
>>> meter.val
62+
0.7
63+
>>> meter.avg
64+
0.7
65+
>>> meter.reset()
66+
>>> meter.val
67+
0
68+
>>> meter.avg
69+
nan
70+
"""
71+
72+
self.val = 0
73+
self.n = 1
74+
self.sum = 0
75+
self.count = 0
76+
77+
def update(self, value, n: float = 1) -> None:
78+
r"""
79+
Updates the average and current value in the meter.
80+
81+
Args:
82+
value: Value to be added to the average.
83+
n: Number of values to be added.
84+
85+
Examples:
86+
>>> meter = AverageMeter()
87+
>>> meter.update(0.7)
88+
>>> meter.val
89+
0.7
90+
>>> meter.avg
91+
0.7
92+
>>> meter.update(0.9)
93+
>>> meter.val
94+
0.9
95+
>>> meter.avg
96+
0.8
97+
>>> meter.sum
98+
1.6
99+
>>> meter.count
100+
2
101+
"""
102+
103+
self.val = value
104+
self.n = n
105+
self.sum += value * n
106+
self.count += n
107+
108+
def value(self):
109+
return self.val
110+
111+
def batch(self):
112+
world_size = get_world_size()
113+
if world_size == 1:
114+
return self.val / self.n if self.n != 0 else float("nan")
115+
synced_tuple = [None for _ in range(world_size)]
116+
dist.all_gather_object(synced_tuple, (self.val * self.n, self.n))
117+
val, n = zip(*synced_tuple)
118+
count = sum(n)
119+
if count == 0:
120+
return float("nan")
121+
return sum(val) / count
122+
123+
def average(self):
124+
world_size = get_world_size()
125+
if world_size == 1:
126+
return self.sum / self.count if self.count != 0 else float("nan")
127+
synced_tuple = [None for _ in range(world_size)]
128+
dist.all_gather_object(synced_tuple, (self.sum, self.count))
129+
val, n = zip(*synced_tuple)
130+
count = sum(n)
131+
if count == 0:
132+
return float("nan")
133+
return sum(val) / count
134+
135+
@property
136+
def bat(self):
137+
return self.batch()
138+
139+
@property
140+
def avg(self):
141+
return self.average()
142+
143+
def __format__(self, format_spec) -> str:
144+
return f"{self.val.__format__(format_spec)} ({self.avg.__format__(format_spec)})"
145+
146+
147+
class MultiTaskAverageMeter(MultiTaskDict):
148+
"""
149+
Examples:
150+
>>> meters = MultiTaskAverageMeter()
151+
>>> meters.update({"loss": 0.6, "dataset1.cls.auroc": 0.7, "dataset1.reg.r2": 0.8, "dataset2.r2": 0.9})
152+
>>> print(f"{meters:.4f}")
153+
loss: 0.6000 (0.6000)
154+
dataset1.cls.auroc: 0.7000 (0.7000)
155+
dataset1.reg.r2: 0.8000 (0.8000)
156+
dataset2.r2: 0.9000 (0.9000)
157+
>>> meters.update({"loss": {"value": 0.9, "n": 1}})
158+
>>> print(f"{meters:.4f}")
159+
loss: 0.9000 (0.7500)
160+
dataset1.cls.auroc: 0.7000 (0.7000)
161+
dataset1.reg.r2: 0.8000 (0.8000)
162+
dataset2.r2: 0.9000 (0.9000)
163+
>>> meters.sum.dict()
164+
{'loss': 1.5, 'dataset1': {'cls': {'auroc': 0.7}, 'reg': {'r2': 0.8}}, 'dataset2': {'r2': 0.9}}
165+
>>> meters.count.dict()
166+
{'loss': 2, 'dataset1': {'cls': {'auroc': 1}, 'reg': {'r2': 1}}, 'dataset2': {'r2': 1}}
167+
>>> meters.reset()
168+
>>> print(f"{meters:.4f}")
169+
loss: 0.0000 (nan)
170+
dataset1.cls.auroc: 0.0000 (nan)
171+
dataset1.reg.r2: 0.0000 (nan)
172+
dataset2.r2: 0.0000 (nan)
173+
>>> meters = MultiTaskAverageMeter(return_average=True)
174+
>>> meters.update({"loss": 0.6, "dataset1.a.auroc": 0.7, "dataset1.b.auroc": 0.8, "dataset2.auroc": 0.9})
175+
>>> print(f"{meters:.4f}")
176+
loss: 0.6000 (0.6000)
177+
dataset1.a.auroc: 0.7000 (0.7000)
178+
dataset1.b.auroc: 0.8000 (0.8000)
179+
dataset2.auroc: 0.9000 (0.9000)
180+
>>> meters.update({"loss": 0.9, "dataset1.a.auroc": 0.8, "dataset1.b.auroc": 0.9, "dataset2.auroc": 1.0})
181+
>>> print(f"{meters:.4f}")
182+
loss: 0.9000 (0.7500)
183+
dataset1.a.auroc: 0.8000 (0.7500)
184+
dataset1.b.auroc: 0.9000 (0.8500)
185+
dataset2.auroc: 1.0000 (0.9500)
186+
"""
187+
188+
@property
189+
def sum(self) -> NestedDict[str, float]:
190+
return NestedDict({key: meter.sum for key, meter in self.all_items()})
191+
192+
@property
193+
def count(self) -> NestedDict[str, int]:
194+
return NestedDict({key: meter.count for key, meter in self.all_items()})
195+
196+
def update(self, values: Dict, *, n: int = 1) -> None: # pylint: disable=W0237
197+
r"""
198+
Updates the average and current value in all meters.
199+
200+
Args:
201+
values: Dict of values to be added to the average.
202+
n: Number of values to be added.
203+
204+
Raises:
205+
ValueError: If the value is not an instance of (int, float, Mapping).
206+
207+
Examples:
208+
>>> meters = MultiTaskAverageMeter()
209+
>>> meters.update({"loss": 0.6, "dataset1.cls.auroc": 0.7, "dataset1.reg.r2": 0.8, "dataset2.r2": 0.9})
210+
>>> meters.sum.dict()
211+
{'loss': 0.6, 'dataset1': {'cls': {'auroc': 0.7}, 'reg': {'r2': 0.8}}, 'dataset2': {'r2': 0.9}}
212+
>>> meters.count.dict()
213+
{'loss': 1, 'dataset1': {'cls': {'auroc': 1}, 'reg': {'r2': 1}}, 'dataset2': {'r2': 1}}
214+
>>> meters.update({"loss": {"value": 0.9, "n": 1}})
215+
>>> meters.sum.dict()
216+
{'loss': 1.5, 'dataset1': {'cls': {'auroc': 0.7}, 'reg': {'r2': 0.8}}, 'dataset2': {'r2': 0.9}}
217+
>>> meters.count.dict()
218+
{'loss': 2, 'dataset1': {'cls': {'auroc': 1}, 'reg': {'r2': 1}}, 'dataset2': {'r2': 1}}
219+
>>> meters.update({"loss": 0.8, "dataset1.cls.auroc": 0.9, "dataset1.reg.r2": 0.8, "dataset2.r2": 0.7})
220+
>>> meters.sum.dict()
221+
{'loss': 2.3, 'dataset1': {'cls': {'auroc': 1.6}, 'reg': {'r2': 1.6}}, 'dataset2': {'r2': 1.6}}
222+
>>> meters.count.dict()
223+
{'loss': 3, 'dataset1': {'cls': {'auroc': 2}, 'reg': {'r2': 2}}, 'dataset2': {'r2': 2}}
224+
>>> meters.update({"dataset1.cls.auroc": 0.7, "dataset1.reg.r2": 0.7, "dataset2.r2": 0.9})
225+
>>> meters.sum.dict()
226+
{'loss': 2.3, 'dataset1': {'cls': {'auroc': 2.3}, 'reg': {'r2': 2.3}}, 'dataset2': {'r2': 2.5}}
227+
>>> meters.count.dict()
228+
{'loss': 3, 'dataset1': {'cls': {'auroc': 3}, 'reg': {'r2': 3}}, 'dataset2': {'r2': 3}}
229+
>>> meters.update({"dataset1": {"cls.auroc": 0.9}, "dataset1.reg.r2": 0.8, "dataset2.r2": 0.9})
230+
Traceback (most recent call last):
231+
ValueError: Expected values to be int, float, or a flat dictionary, but got <class 'dict'>
232+
This is likely due to nested dictionary in the values.
233+
Nested dictionaries cannot be processed due to the method's design, which uses Mapping to pass both value and count ('n'). Ensure your input is a flat dictionary or a single value.
234+
>>> meters.update(dict(loss=""))
235+
Traceback (most recent call last):
236+
ValueError: Expected values to be int, float, or a flat dictionary, but got <class 'str'>
237+
""" # noqa: E501
238+
239+
for meter, value in values.items():
240+
if isinstance(value, (int, float)):
241+
self[meter].update(value, n)
242+
elif isinstance(value, Dict):
243+
value.setdefault("n", n)
244+
try:
245+
self[meter].update(**value)
246+
except TypeError:
247+
raise ValueError(
248+
f"Expected values to be int, float, or a flat dictionary, but got {type(value)}\n"
249+
"This is likely due to nested dictionary in the values.\n"
250+
"Nested dictionaries cannot be processed due to the method's design, which uses Mapping "
251+
"to pass both value and count ('n'). Ensure your input is a flat dictionary or a single value."
252+
) from None
253+
else:
254+
raise ValueError(f"Expected values to be int, float, or a flat dictionary, but got {type(value)}")
255+
256+
# eval hack, as the default_factory must not be set to make `NestedDict` happy
257+
# this have some side effects, it will break attribute style intermediate nested dict auto creation
258+
# but everything has a price
259+
def get(self, name: Any, default=None) -> Any:
260+
if not name.startswith("_") and not name.endswith("_"):
261+
return self.setdefault(name, AverageMeter())
262+
return super().get(name, default)

0 commit comments

Comments
 (0)