Skip to content

Commit cfdb3f3

Browse files
refine eval checking and logging
1 parent 3891566 commit cfdb3f3

File tree

3 files changed

+14
-23
lines changed

3 files changed

+14
-23
lines changed

ppsci/solver/eval.py

Lines changed: 0 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
from paddle import io
2626

2727
from ppsci.solver import printer
28-
from ppsci.utils import logger
2928
from ppsci.utils import misc
3029

3130
if TYPE_CHECKING:
@@ -91,11 +90,6 @@ def _eval_by_dataset(
9190
input_dict, label_dict, weight_dict = batch
9291
reader_cost = time.perf_counter() - reader_tic
9392

94-
# NOTE: eliminate first 5 step for warmup
95-
if iter_id == 5:
96-
for key in solver.eval_time_info:
97-
solver.eval_time_info[key].reset()
98-
9993
for v in input_dict.values():
10094
if hasattr(v, "stop_gradient"):
10195
v.stop_gradient = False
@@ -168,11 +162,6 @@ def _eval_by_dataset(
168162
for metric_name, metric_func in _validator.metric.items():
169163
# NOTE: compute metric with entire output and label
170164
metric_dict = metric_func(all_output, all_label)
171-
if metric_name in metric_dict_group:
172-
logger.warning(
173-
f"Metric name({metric_name}) already exists, please ensure "
174-
"all metric names are unique over all validators."
175-
)
176165
metric_dict_group[metric_name] = {
177166
k: float(v) for k, v in metric_dict.items()
178167
}
@@ -227,11 +216,6 @@ def _eval_by_batch(
227216
input_dict, label_dict, weight_dict = batch
228217
reader_cost = time.perf_counter() - reader_tic
229218

230-
# NOTE: eliminate first 5 step for warmup
231-
if iter_id == 5:
232-
for key in solver.eval_time_info:
233-
solver.eval_time_info[key].reset()
234-
235219
batch_size = next(iter(input_dict.values())).shape[0]
236220
for v in input_dict.values():
237221
if hasattr(v, "stop_gradient"):
@@ -287,11 +271,6 @@ def _eval_by_batch(
287271

288272
# concatenate all metric and discard metric of padded sample(s)
289273
for metric_name, metric_dict in metric_dict_group.items():
290-
# if metric_name in metric_dict_group:
291-
# logger.warning(
292-
# f"Metric name({metric_name}) already exists, please ensure "
293-
# "all metric names are unique over all validators."
294-
# )
295274
for var_name, metric_value in metric_dict.items():
296275
# NOTE: concat single metric(scalar) list into metric vector
297276
metric_value = paddle.concat(metric_value)[:num_samples]

ppsci/solver/printer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -149,8 +149,8 @@ def log_eval_info(
149149
)
150150

151151
# reset time information after printing
152-
for key in solver.train_time_info:
153-
solver.train_time_info[key].reset()
152+
for key in solver.eval_time_info:
153+
solver.eval_time_info[key].reset()
154154

155155
# logger.scalar(
156156
# {

ppsci/solver/solver.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,18 @@ def __init__(
266266
f"{self.compute_metric_by_batch} when compute_metric_by_batch="
267267
f"{self.compute_metric_by_batch}."
268268
)
269+
# check metric name uniqueness over all validators
270+
_count = {}
271+
for _validator in validator.values():
272+
for metric_name in _validator.metric:
273+
if metric_name in _count:
274+
logger.warning(
275+
f"Metric name({metric_name}) is duplicated, please ensure "
276+
"all metric names are unique over all given validators."
277+
)
278+
_count[metric_name] = 1
279+
del _count
280+
269281
# whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
270282
if not cfg:
271283
self.eval_with_no_grad = eval_with_no_grad

0 commit comments

Comments (0)