Skip to content

Commit d525cfd

Browse files
authored
[Fix] Fix bug of lr updater hook (#907)
* [Fix] fix bug of lr update hook * [Fix] fix bug of lr update hook * [Fix] Fix bug of lr updater hook
1 parent 03a2e3a commit d525cfd

File tree

3 files changed

+287
-112
lines changed

3 files changed

+287
-112
lines changed

mmcv/runner/hooks/lr_updater.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,26 @@ def get_regular_lr(self, runner):
8282
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
8383

8484
def get_warmup_lr(self, cur_iters):
85-
if self.warmup == 'constant':
86-
warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
87-
elif self.warmup == 'linear':
88-
k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
89-
warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
90-
elif self.warmup == 'exp':
91-
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
92-
warmup_lr = [_lr * k for _lr in self.regular_lr]
93-
return warmup_lr
85+
86+
def _get_warmup_lr(cur_iters, regular_lr):
87+
if self.warmup == 'constant':
88+
warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
89+
elif self.warmup == 'linear':
90+
k = (1 - cur_iters / self.warmup_iters) * (1 -
91+
self.warmup_ratio)
92+
warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
93+
elif self.warmup == 'exp':
94+
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
95+
warmup_lr = [_lr * k for _lr in regular_lr]
96+
return warmup_lr
97+
98+
if isinstance(self.regular_lr, dict):
99+
lr_groups = {}
100+
for key, regular_lr in self.regular_lr.items():
101+
lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
102+
return lr_groups
103+
else:
104+
return _get_warmup_lr(cur_iters, self.regular_lr)
94105

95106
def before_run(self, runner):
96107
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,

mmcv/runner/hooks/momentum_updater.py

+84-36
Original file line numberDiff line numberDiff line change
@@ -31,51 +31,97 @@ def __init__(self,
3131
] # expected momentum if no warming up is performed
3232

3333
def _set_momentum(self, runner, momentum_groups):
34-
for param_group, mom in zip(runner.optimizer.param_groups,
35-
momentum_groups):
36-
if 'momentum' in param_group.keys():
37-
param_group['momentum'] = mom
38-
elif 'betas' in param_group.keys():
39-
param_group['betas'] = (mom, param_group['betas'][1])
34+
if isinstance(runner.optimizer, dict):
35+
for k, optim in runner.optimizer.items():
36+
for param_group, mom in zip(optim.param_groups,
37+
momentum_groups[k]):
38+
if 'momentum' in param_group.keys():
39+
param_group['momentum'] = mom
40+
elif 'betas' in param_group.keys():
41+
param_group['betas'] = (mom, param_group['betas'][1])
42+
else:
43+
for param_group, mom in zip(runner.optimizer.param_groups,
44+
momentum_groups):
45+
if 'momentum' in param_group.keys():
46+
param_group['momentum'] = mom
47+
elif 'betas' in param_group.keys():
48+
param_group['betas'] = (mom, param_group['betas'][1])
4049

4150
def get_momentum(self, runner, base_momentum):
4251
raise NotImplementedError
4352

4453
def get_regular_momentum(self, runner):
45-
return [
46-
self.get_momentum(runner, _base_momentum)
47-
for _base_momentum in self.base_momentum
48-
]
54+
if isinstance(runner.optimizer, dict):
55+
momentum_groups = {}
56+
for k in runner.optimizer.keys():
57+
_momentum_group = [
58+
self.get_momentum(runner, _base_momentum)
59+
for _base_momentum in self.base_momentum[k]
60+
]
61+
momentum_groups.update({k: _momentum_group})
62+
return momentum_groups
63+
else:
64+
return [
65+
self.get_momentum(runner, _base_momentum)
66+
for _base_momentum in self.base_momentum
67+
]
4968

5069
def get_warmup_momentum(self, cur_iters):
51-
if self.warmup == 'constant':
52-
warmup_momentum = [
53-
_momentum / self.warmup_ratio
54-
for _momentum in self.regular_momentum
55-
]
56-
elif self.warmup == 'linear':
57-
k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
58-
warmup_momentum = [
59-
_momentum / (1 - k) for _momentum in self.regular_mom
60-
]
61-
elif self.warmup == 'exp':
62-
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
63-
warmup_momentum = [_momentum / k for _momentum in self.regular_mom]
64-
return warmup_momentum
70+
71+
def _get_warmup_momentum(cur_iters, regular_momentum):
72+
if self.warmup == 'constant':
73+
warmup_momentum = [
74+
_momentum / self.warmup_ratio
75+
for _momentum in self.regular_momentum
76+
]
77+
elif self.warmup == 'linear':
78+
k = (1 - cur_iters / self.warmup_iters) * (1 -
79+
self.warmup_ratio)
80+
warmup_momentum = [
81+
_momentum / (1 - k) for _momentum in self.regular_mom
82+
]
83+
elif self.warmup == 'exp':
84+
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
85+
warmup_momentum = [
86+
_momentum / k for _momentum in self.regular_mom
87+
]
88+
return warmup_momentum
89+
90+
if isinstance(self.regular_momentum, dict):
91+
momentum_groups = {}
92+
for key, regular_momentum in self.regular_momentum.items():
93+
momentum_groups[key] = _get_warmup_momentum(
94+
cur_iters, regular_momentum)
95+
return momentum_groups
96+
else:
97+
return _get_warmup_momentum(cur_iters, self.regular_momentum)
6598

6699
def before_run(self, runner):
67100
# NOTE: when resuming from a checkpoint,
68101
# if 'initial_momentum' is not saved,
69102
# it will be set according to the optimizer params
70-
for group in runner.optimizer.param_groups:
71-
if 'momentum' in group.keys():
72-
group.setdefault('initial_momentum', group['momentum'])
73-
else:
74-
group.setdefault('initial_momentum', group['betas'][0])
75-
self.base_momentum = [
76-
group['initial_momentum']
77-
for group in runner.optimizer.param_groups
78-
]
103+
if isinstance(runner.optimizer, dict):
104+
self.base_momentum = {}
105+
for k, optim in runner.optimizer.items():
106+
for group in optim.param_groups:
107+
if 'momentum' in group.keys():
108+
group.setdefault('initial_momentum', group['momentum'])
109+
else:
110+
group.setdefault('initial_momentum', group['betas'][0])
111+
_base_momentum = [
112+
group['initial_momentum'] for group in optim.param_groups
113+
]
114+
self.base_momentum.update({k: _base_momentum})
115+
else:
116+
for group in runner.optimizer.param_groups:
117+
if 'momentum' in group.keys():
118+
group.setdefault('initial_momentum', group['momentum'])
119+
else:
120+
group.setdefault('initial_momentum', group['betas'][0])
121+
self.base_momentum = [
122+
group['initial_momentum']
123+
for group in runner.optimizer.param_groups
124+
]
79125

80126
def before_train_epoch(self, runner):
81127
if not self.by_epoch:
@@ -383,9 +429,11 @@ def get_regular_momentum(self, runner):
383429
if isinstance(runner.optimizer, dict):
384430
momentum_groups = {}
385431
for k, optim in runner.optimizer.items():
386-
for param_group in optim.param_groups:
387-
momentum_groups[k].append(
388-
self.get_momentum(runner, param_group))
432+
_momentum_group = [
433+
self.get_momentum(runner, param_group)
434+
for param_group in optim.param_groups
435+
]
436+
momentum_groups.update({k: _momentum_group})
389437
return momentum_groups
390438
else:
391439
momentum_groups = []

0 commit comments

Comments
 (0)