diff --git a/mmcv/runner/__init__.py b/mmcv/runner/__init__.py index e9b41a64c4..bba39f5c90 100644 --- a/mmcv/runner/__init__.py +++ b/mmcv/runner/__init__.py @@ -15,8 +15,9 @@ Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, GradientCumulativeOptimizerHook, Hook, IterTimerHook, LoggerHook, MlflowLoggerHook, NeptuneLoggerHook, - OptimizerHook, PaviLoggerHook, SyncBuffersHook, - TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook) + OptimizerHook, PaviLoggerHook, SegmindLoggerHook, + SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook, + WandbLoggerHook) from .hooks.lr_updater import StepLrUpdaterHook # noqa from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook, CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, @@ -60,5 +61,6 @@ 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule', '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential', 'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook', - 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor' + 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor', + 'SegmindLoggerHook' ] diff --git a/mmcv/runner/hooks/__init__.py b/mmcv/runner/hooks/__init__.py index 671bb0c2e9..121073eb58 100644 --- a/mmcv/runner/hooks/__init__.py +++ b/mmcv/runner/hooks/__init__.py @@ -6,8 +6,8 @@ from .hook import HOOKS, Hook from .iter_timer import IterTimerHook from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook, - NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook, - TextLoggerHook, WandbLoggerHook) + NeptuneLoggerHook, PaviLoggerHook, SegmindLoggerHook, + TensorboardLoggerHook, TextLoggerHook, WandbLoggerHook) from .lr_updater import (CosineAnnealingLrUpdaterHook, CosineRestartLrUpdaterHook, CyclicLrUpdaterHook, ExpLrUpdaterHook, FixedLrUpdaterHook, @@ -38,5 +38,6 @@ 'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook', 'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook', 
'DistEvalHook', 'ProfilerHook', - 'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook' + 'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook', + 'SegmindLoggerHook' ] diff --git a/mmcv/runner/hooks/logger/__init__.py b/mmcv/runner/hooks/logger/__init__.py index a0b6b34564..8ce580d392 100644 --- a/mmcv/runner/hooks/logger/__init__.py +++ b/mmcv/runner/hooks/logger/__init__.py @@ -4,6 +4,7 @@ from .mlflow import MlflowLoggerHook from .neptune import NeptuneLoggerHook from .pavi import PaviLoggerHook +from .segmind import SegmindLoggerHook from .tensorboard import TensorboardLoggerHook from .text import TextLoggerHook from .wandb import WandbLoggerHook @@ -11,5 +12,5 @@ __all__ = [ 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook', 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook', - 'NeptuneLoggerHook', 'DvcliveLoggerHook' + 'NeptuneLoggerHook', 'DvcliveLoggerHook', 'SegmindLoggerHook' ] diff --git a/mmcv/runner/hooks/logger/segmind.py b/mmcv/runner/hooks/logger/segmind.py new file mode 100644 index 0000000000..e262c7c1aa --- /dev/null +++ b/mmcv/runner/hooks/logger/segmind.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ...dist_utils import master_only +from ..hook import HOOKS +from .base import LoggerHook + + +@HOOKS.register_module() +class SegmindLoggerHook(LoggerHook): + """Class to log metrics to Segmind. + + It requires `Segmind`_ to be installed. + + Args: + interval (int): Logging interval (every k iterations). Default: 10. + ignore_last (bool): Ignore the log of last iterations in each epoch + if less than `interval`. Default: True. + reset_flag (bool): Whether to clear the output buffer after logging. + Default: False. + by_epoch (bool): Whether EpochBasedRunner is used. Default: True. + + .. 
_Segmind: + https://docs.segmind.com/python-library + """ + + def __init__(self, + interval=10, + ignore_last=True, + reset_flag=False, + by_epoch=True): + super(SegmindLoggerHook, self).__init__(interval, ignore_last, + reset_flag, by_epoch) + self.import_segmind() + + def import_segmind(self): + try: + import segmind + except ImportError: + raise ImportError( + "Please run 'pip install segmind' to install segmind") + self.log_metrics = segmind.tracking.fluent.log_metrics + self.mlflow_log = segmind.utils.logging_utils.try_mlflow_log + + @master_only + def log(self, runner): + tags = self.get_loggable_tags(runner) + if tags: + # logging metrics to segmind + self.mlflow_log( + self.log_metrics, tags, step=runner.epoch, epoch=runner.epoch) diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index 8f56cf2bd7..a729e08ed9 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -21,12 +21,15 @@ from torch.utils.data import DataLoader from mmcv.fileio.file_client import PetrelBackend +# yapf: disable from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook, Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook, GradientCumulativeOptimizerHook, IterTimerHook, MlflowLoggerHook, NeptuneLoggerHook, OptimizerHook, - PaviLoggerHook, WandbLoggerHook, build_runner) + PaviLoggerHook, SegmindLoggerHook, WandbLoggerHook, + build_runner) +# yapf: enable from mmcv.runner.fp16_utils import auto_fp16 from mmcv.runner.hooks.hook import HOOKS, Hook from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook, @@ -1400,6 +1403,25 @@ def test_mlflow_hook(log_model): assert not hook.mlflow_pytorch.log_model.called +def test_segmind_hook(): + sys.modules['segmind'] = MagicMock() + runner = _build_demo_runner() + hook = SegmindLoggerHook() + loader = DataLoader(torch.ones((5, 2))) + + runner.register_hook(hook) + runner.run([loader, loader], [('train', 1), ('val', 1)]) + shutil.rmtree(runner.work_dir) + + 
hook.mlflow_log.assert_called_with( + hook.log_metrics, { + 'learning_rate': 0.02, + 'momentum': 0.95 + }, + step=runner.epoch, + epoch=runner.epoch) + + def test_wandb_hook(): sys.modules['wandb'] = MagicMock() runner = _build_demo_runner()