@@ -70,11 +70,11 @@ class MetricMeter(AverageMeter):
70
70
"""
71
71
72
72
metric : Callable
73
- preprocess : Callable
73
+ preprocess : Optional [ Callable ] = None
74
74
ignored_index : Optional [int ] = None
75
75
76
76
def __init__ (
77
- self , metric : Callable , preprocess : Callable = default_preprocess , ignored_index : int | None = None
77
+ self , metric : Callable , preprocess : Callable | None = default_preprocess , ignored_index : int | None = None
78
78
) -> None :
79
79
self .metric = metric
80
80
self .preprocess = preprocess
@@ -93,7 +93,8 @@ def update( # type: ignore[override] # pylint: disable=W0237
93
93
value: Value to be added to the average.
94
94
n: Number of values to be added.
95
95
"""
96
- input , target = self .preprocess (input , target , ignored_index = self .ignored_index )
96
+ if self .preprocess is not None :
97
+ input , target = self .preprocess (input , target , ignored_index = self .ignored_index )
97
98
n = len (input )
98
99
super ().update (self .metric (input , target ).item () * n , n = n )
99
100
@@ -135,27 +136,20 @@ class MetricMeters(AverageMeters):
135
136
TypeError: ...update() missing 1 required positional argument: 'target'
136
137
"""
137
138
138
- preprocess : Callable
139
+ preprocess = None
139
140
ignored_index = None
140
141
141
142
def __init__ (
142
- self , * args , preprocess : Callable = default_preprocess , ignored_index : int | None = None , ** kwargs
143
+ self , preprocess : Callable | None = default_preprocess , ignored_index : int | None = None , ** meters
143
144
) -> None :
144
145
self .setattr ("preprocess" , preprocess )
145
146
self .setattr ("ignored_index" , ignored_index )
146
- for meter in args :
147
+ for name , meter in meters . items () :
147
148
if callable (meter ):
148
- meter = MetricMeter (meter , ignored_index = self .ignored_index )
149
- if not isinstance (meter , MetricMeter ):
150
- raise ValueError (f"Expected meter to be an instance of MetricMeter, but got { type (meter )} " )
151
- for name , meter in kwargs .items ():
152
- if callable (meter ):
153
- kwargs [name ] = meter = MetricMeter (meter , ignored_index = self .ignored_index )
149
+ meters [name ] = meter = MetricMeter (meter , preprocess = None , ignored_index = self .ignored_index )
154
150
if not isinstance (meter , MetricMeter ):
155
151
raise ValueError (f"Expected { name } to be an instance of MetricMeter, but got { type (meter )} " )
156
- if ignored_index is not None :
157
- self .setattr ("ignored_index" , ignored_index )
158
- super ().__init__ (* args , default_factory = None , ** kwargs ) # type: ignore[arg-type]
152
+ super ().__init__ (default_factory = None , ** meters ) # type: ignore[arg-type]
159
153
160
154
def update ( # type: ignore[override] # pylint: disable=W0221
161
155
self ,
@@ -170,13 +164,14 @@ def update( # type: ignore[override] # pylint: disable=W0221
170
164
target: Target values to compute the metrics.
171
165
"""
172
166
173
- input , target = self .preprocess (input , target , ignored_index = self .ignored_index )
167
+ if self .preprocess is not None :
168
+ input , target = self .preprocess (input , target , ignored_index = self .ignored_index )
174
169
for meter in self .values ():
175
170
meter .update (input , target )
176
171
177
172
def set (self , name : str , meter : MetricMeter | Callable ) -> None : # type: ignore[override] # pylint: disable=W0237
178
173
if callable (meter ):
179
- meter = MetricMeter (meter , ignored_index = self .ignored_index )
174
+ meter = MetricMeter (meter , preprocess = None , ignored_index = self .ignored_index )
180
175
if not isinstance (meter , MetricMeter ):
181
176
raise ValueError (f"Expected meter to be an instance of MetricMeter, but got { type (meter )} " )
182
177
super ().set (name , meter )
@@ -260,9 +255,13 @@ def set( # pylint: disable=W0237
260
255
name : str ,
261
256
metric : MetricMeter | MetricMeters | Callable , # type: ignore[override]
262
257
) -> None :
263
- if callable (metric ):
258
+ from .metrics import Metrics
259
+
260
+ if isinstance (metric , Metrics ):
261
+ metric = MetricMeters (preprocess = metric .preprocess , ignored_index = metric .ignored_index , ** metric .metrics )
262
+ elif callable (metric ):
264
263
metric = MetricMeter (metric )
265
- if not isinstance (metric , (MetricMeter , MetricMeters )):
264
+ elif not isinstance (metric , (MetricMeter , MetricMeters )):
266
265
raise ValueError (
267
266
f"Expected { metric } to be an instance of MetricMeter or MetricMeters, but got { type (metric )} "
268
267
)
0 commit comments