Skip to content

Commit fa1b153

Browse files
committed
fix preprocess in MetricMeter
Signed-off-by: Zhiyuan Chen <this@zyc.ai>
1 parent ec53dfe commit fa1b153

File tree

4 files changed

+32
-34
lines changed

4 files changed

+32
-34
lines changed

danling/metrics/average_meter.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class AverageMeter:
5050
>>> meter.update(0.7)
5151
>>> meter.val
5252
0.7
53+
>>> meter.bat
54+
0.7
5355
>>> meter.avg
5456
0.7
5557
>>> meter.update(0.9)
@@ -167,14 +169,11 @@ class AverageMeters(MetricsDict):
167169
'loss: 0.0000 (nan)\tauroc: 0.0000 (nan)\tr2: 0.0000 (nan)'
168170
"""
169171

170-
def __init__(self, *args, default_factory: Type[AverageMeter] = AverageMeter, **kwargs) -> None:
171-
for meter in args:
172-
if not isinstance(meter, AverageMeter):
173-
raise ValueError(f"Expected meter to be an instance of AverageMeter, but got {type(meter)}")
174-
for name, meter in kwargs.items():
172+
def __init__(self, default_factory: Type[AverageMeter] = AverageMeter, **meters) -> None:
173+
for name, meter in meters.items():
175174
if not isinstance(meter, AverageMeter):
176175
raise ValueError(f"Expected {name} to be an instance of AverageMeter, but got {type(meter)}")
177-
super().__init__(*args, default_factory=default_factory, **kwargs)
176+
super().__init__(default_factory=default_factory, **meters)
178177

179178
@property
180179
def sum(self) -> FlatDict[str, float]:
@@ -288,4 +287,4 @@ def set(self, name: str, meter: AverageMeter | AverageMeters) -> None: # pylint
288287
raise ValueError(
289288
f"Expected meter to be an instance of AverageMeter or AverageMeters, but got {type(meter)}"
290289
)
291-
super().set(name, meter)
290+
super().set(name, meter, convert_mapping=False)

danling/metrics/metric_meter.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ class MetricMeter(AverageMeter):
7070
"""
7171

7272
metric: Callable
73-
preprocess: Callable
73+
preprocess: Optional[Callable] = None
7474
ignored_index: Optional[int] = None
7575

7676
def __init__(
77-
self, metric: Callable, preprocess: Callable = default_preprocess, ignored_index: int | None = None
77+
self, metric: Callable, preprocess: Callable | None = default_preprocess, ignored_index: int | None = None
7878
) -> None:
7979
self.metric = metric
8080
self.preprocess = preprocess
@@ -93,7 +93,8 @@ def update( # type: ignore[override] # pylint: disable=W0237
9393
value: Value to be added to the average.
9494
n: Number of values to be added.
9595
"""
96-
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
96+
if self.preprocess is not None:
97+
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
9798
n = len(input)
9899
super().update(self.metric(input, target).item() * n, n=n)
99100

@@ -135,27 +136,20 @@ class MetricMeters(AverageMeters):
135136
TypeError: ...update() missing 1 required positional argument: 'target'
136137
"""
137138

138-
preprocess: Callable
139+
preprocess = None
139140
ignored_index = None
140141

141142
def __init__(
142-
self, *args, preprocess: Callable = default_preprocess, ignored_index: int | None = None, **kwargs
143+
self, preprocess: Callable | None = default_preprocess, ignored_index: int | None = None, **meters
143144
) -> None:
144145
self.setattr("preprocess", preprocess)
145146
self.setattr("ignored_index", ignored_index)
146-
for meter in args:
147+
for name, meter in meters.items():
147148
if callable(meter):
148-
meter = MetricMeter(meter, ignored_index=self.ignored_index)
149-
if not isinstance(meter, MetricMeter):
150-
raise ValueError(f"Expected meter to be an instance of MetricMeter, but got {type(meter)}")
151-
for name, meter in kwargs.items():
152-
if callable(meter):
153-
kwargs[name] = meter = MetricMeter(meter, ignored_index=self.ignored_index)
149+
meters[name] = meter = MetricMeter(meter, preprocess=None, ignored_index=self.ignored_index)
154150
if not isinstance(meter, MetricMeter):
155151
raise ValueError(f"Expected {name} to be an instance of MetricMeter, but got {type(meter)}")
156-
if ignored_index is not None:
157-
self.setattr("ignored_index", ignored_index)
158-
super().__init__(*args, default_factory=None, **kwargs) # type: ignore[arg-type]
152+
super().__init__(default_factory=None, **meters) # type: ignore[arg-type]
159153

160154
def update( # type: ignore[override] # pylint: disable=W0221
161155
self,
@@ -170,13 +164,14 @@ def update( # type: ignore[override] # pylint: disable=W0221
170164
target: Target values to compute the metrics.
171165
"""
172166

173-
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
167+
if self.preprocess is not None:
168+
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
174169
for meter in self.values():
175170
meter.update(input, target)
176171

177172
def set(self, name: str, meter: MetricMeter | Callable) -> None: # type: ignore[override] # pylint: disable=W0237
178173
if callable(meter):
179-
meter = MetricMeter(meter, ignored_index=self.ignored_index)
174+
meter = MetricMeter(meter, preprocess=None, ignored_index=self.ignored_index)
180175
if not isinstance(meter, MetricMeter):
181176
raise ValueError(f"Expected meter to be an instance of MetricMeter, but got {type(meter)}")
182177
super().set(name, meter)
@@ -260,9 +255,13 @@ def set( # pylint: disable=W0237
260255
name: str,
261256
metric: MetricMeter | MetricMeters | Callable, # type: ignore[override]
262257
) -> None:
263-
if callable(metric):
258+
from .metrics import Metrics
259+
260+
if isinstance(metric, Metrics):
261+
metric = MetricMeters(preprocess=metric.preprocess, ignored_index=metric.ignored_index, **metric.metrics)
262+
elif callable(metric):
264263
metric = MetricMeter(metric)
265-
if not isinstance(metric, (MetricMeter, MetricMeters)):
264+
elif not isinstance(metric, (MetricMeter, MetricMeters)):
266265
raise ValueError(
267266
f"Expected {metric} to be an instance of MetricMeter or MetricMeters, but got {type(metric)}"
268267
)

danling/metrics/metrics.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class Metrics(Metric):
119119
"""
120120

121121
metrics: FlatDict[str, Callable]
122-
preprocess: Callable
122+
preprocess: Optional[Callable] = None
123123
ignored_index: Optional[int] = None
124124
_input: Tensor
125125
_target: Tensor
@@ -182,6 +182,7 @@ def update(self, input: Tensor | NestedTensor | Sequence, target: Tensor | Neste
182182
else:
183183
raise ValueError(f"Unknown input and target: {input}, {target}")
184184
self.flatten = True
185+
# If the batch has only 1 sample at the end of an epoch, we need to flatten the input and target
185186
elif self.flatten:
186187
target = target.flatten()
187188
input = input.flatten() if input.numel() == target.numel() else input.view(*target.shape, -1)
@@ -196,6 +197,8 @@ def update(self, input: Tensor | NestedTensor | Sequence, target: Tensor | Neste
196197
if self.world_size > 1:
197198
input, target = self._sync(input), self._sync(target)
198199
input, target = input.detach().to(self.device), target.detach().to(self.device)
200+
if self.preprocess is not None:
201+
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
199202
self._input = input
200203
self._target = target
201204
self._inputs = torch.cat([self._inputs, input]).to(input.dtype)
@@ -228,9 +231,8 @@ def calculate(self, input: Tensor, target: Tensor) -> NestedDict[str, flist | fl
228231
):
229232
return NestedDict({name: nan for name in self.metrics.keys()})
230233
ret = NestedDict()
231-
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
232234
for name, metric in self.metrics.items():
233-
score = self._calculate(metric, input, target, preprocess=False)
235+
score = self._calculate(metric, input, target)
234236
if isinstance(score, Mapping):
235237
if self.merge_dict:
236238
ret.merge(score)
@@ -242,9 +244,7 @@ def calculate(self, input: Tensor, target: Tensor) -> NestedDict[str, flist | fl
242244
return ret
243245

244246
@torch.inference_mode()
245-
def _calculate(self, metric, input: Tensor, target: Tensor, preprocess: bool = True) -> flist | float:
246-
if preprocess:
247-
input, target = self.preprocess(input, target, ignored_index=self.ignored_index)
247+
def _calculate(self, metric, input: Tensor, target: Tensor) -> flist | float:
248248
score = metric(input, target)
249249
if isinstance(score, Tensor):
250250
return score.item() if score.numel() == 1 else flist(score.tolist())

tests/metrics/test_metrics.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def test_nested_tensor_regression(self):
209209
targets.extend(target_list)
210210
pred_nt, target_nt = NestedTensor(pred_list), NestedTensor(target_list)
211211
pred, target = torch.cat(pred_list), torch.cat(target_list)
212-
metrics.update(pred_nt, target_nt)
212+
metrics.update(pred_nt.tensor, target_nt)
213213
for metric in metric_map.values():
214214
metric.update(pred.view(-1, num_outputs), target.view(-1, num_outputs))
215215
value, average = metrics.value(), metrics.average()
@@ -278,7 +278,7 @@ def test_nested_tensor_binary(self):
278278
targets.extend(target_list)
279279
pred_nt, target_nt = NestedTensor(pred_list), NestedTensor(target_list)
280280
pred, target = torch.cat(pred_list), torch.cat(target_list)
281-
metrics.update(pred_nt, target_nt)
281+
metrics.update(pred_nt, target_nt.tensor)
282282
merge_metrics.update(pred_nt, target_nt)
283283
for metric in metric_map.values():
284284
metric.update(pred, target)

0 commit comments

Comments
 (0)