Skip to content

Commit

Permalink
[Enhancement] Add ability to pass logger instance to frameworks (open…
Browse files Browse the repository at this point in the history
…-mmlab#2317)

* Add ability to pass logger instance to frameworks

* refine docstring

* Update mmcv/runner/hooks/logger/dvclive.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
2 people authored and cathyzhang222 committed Oct 20, 2022
1 parent 7eb5356 commit c631779
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
12 changes: 8 additions & 4 deletions mmcv/runner/hooks/logger/dvclive.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ class DvcliveLoggerHook(LoggerHook):
reset_flag (bool): Whether to clear the output buffer after logging.
Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
kwargs: Arguments for instantiating `Live`_.
dvclive (Live, optional): An instance of the `Live`_ logger to use
instead of initializing a new one internally. Defaults to None.
kwargs: Arguments for instantiating `Live`_ (ignored if `dvclive` is
provided).
.. _dvclive:
https://dvc.org/doc/dvclive
Expand All @@ -37,18 +40,19 @@ def __init__(self,
ignore_last: bool = True,
reset_flag: bool = False,
by_epoch: bool = True,
dvclive=None,
**kwargs):
super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.model_file = model_file
self.import_dvclive(**kwargs)
self._import_dvclive(dvclive, **kwargs)

def import_dvclive(self, **kwargs) -> None:
def _import_dvclive(self, dvclive=None, **kwargs) -> None:
try:
from dvclive import Live
except ImportError:
raise ImportError(
'Please run "pip install dvclive" to install dvclive')
self.dvclive = Live(**kwargs)
self.dvclive = dvclive if dvclive is not None else Live(**kwargs)

@master_only
def log(self, runner) -> None:
Expand Down
11 changes: 10 additions & 1 deletion tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,7 +1665,6 @@ def test_dvclive_hook_model_file(tmp_path):
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
runner.register_hook(hook)

loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
loader = DataLoader(torch.ones((5, 2)))

runner.run([loader, loader], [('train', 1), ('val', 1)])
Expand All @@ -1675,6 +1674,16 @@ def test_dvclive_hook_model_file(tmp_path):
shutil.rmtree(runner.work_dir)


def test_dvclive_hook_pass_logger(tmp_path):
sys.modules['dvclive'] = MagicMock()
from dvclive import Live
logger = Live()

sys.modules['dvclive'] = MagicMock()
assert DvcliveLoggerHook().dvclive is not logger
assert DvcliveLoggerHook(dvclive=logger).dvclive is logger


def test_clearml_hook():
sys.modules['clearml'] = MagicMock()
runner = _build_demo_runner()
Expand Down

0 comments on commit c631779

Please sign in to comment.