Skip to content

Commit

Permalink
[Enhance] Support non-scalar type metric value.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Jan 6, 2023
1 parent f10b5ce commit 110e364
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions mmengine/hooks/runtime_info_hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
Expand All @@ -9,6 +12,24 @@
DATA_BATCH = Optional[Union[dict, tuple, list]]


def _is_scalar(value: Any) -> bool:
"""Determine the value is a scalar type value.
Args:
value (Any): value of log.
Returns:
bool: whether the value is a scalar type value.
"""
if isinstance(value, np.ndarray):
return value.size == 1
elif isinstance(value, (int, float)):
return True
elif isinstance(value, torch.Tensor):
return value.numel() == 1
return False


@HOOKS.register_module()
class RuntimeInfoHook(Hook):
"""A hook that updates runtime information into message hub.
Expand Down Expand Up @@ -112,7 +133,10 @@ def after_val_epoch(self,
"""
if metrics is not None:
for key, value in metrics.items():
runner.message_hub.update_scalar(f'val/{key}', value)
if _is_scalar(value):
runner.message_hub.update_scalar(f'val/{key}', value)
else:
runner.logger.info(f'{key}:\n{value}')

def after_test_epoch(self,
runner,
Expand All @@ -128,4 +152,7 @@ def after_test_epoch(self,
"""
if metrics is not None:
for key, value in metrics.items():
runner.message_hub.update_scalar(f'test/{key}', value)
if _is_scalar(value):
runner.message_hub.update_scalar(f'test/{key}', value)
else:
runner.logger.info(f'{key}:\n{value}')

0 comments on commit 110e364

Please sign in to comment.