Skip to content

Commit

Permalink
[Refactor]: Change scheduler to param_scheduler (open-mmlab#121)
Browse files Browse the repository at this point in the history
* [Refactor]: Change scheduler to param_scheduler

* [Fix]: Fix UT of param scheduler hook

Co-authored-by: Your <you@example.com>
  • Loading branch information
YuanLiuuuuuu and Your authored Mar 12, 2022
1 parent 61fecab commit 755f8b5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions mmengine/hooks/param_scheduler_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def after_train_iter(self,
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
for scheduler in runner.schedulers:
for scheduler in runner.param_schedulers: # type: ignore
if not scheduler.by_epoch:
scheduler.step()

Expand All @@ -41,6 +41,6 @@ def after_train_epoch(self, runner) -> None:
Args:
runner (Runner): The runner of the training process.
"""
for scheduler in runner.schedulers:
for scheduler in runner.param_schedulers: # type: ignore
if scheduler.by_epoch:
scheduler.step()
4 changes: 2 additions & 2 deletions tests/test_hook/test_param_scheduler_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_after_iter(self):
scheduler = Mock()
scheduler.step = Mock()
scheduler.by_epoch = False
Runner.schedulers = [scheduler]
Runner.param_schedulers = [scheduler]
Hook.after_train_iter(Runner)
scheduler.step.assert_called()

Expand All @@ -22,6 +22,6 @@ def test_after_epoch(self):
scheduler = Mock()
scheduler.step = Mock()
scheduler.by_epoch = True
Runner.schedulers = [scheduler]
Runner.param_schedulers = [scheduler]
Hook.after_train_epoch(Runner)
scheduler.step.assert_called()

0 comments on commit 755f8b5

Please sign in to comment.