@@ -31,51 +31,97 @@ def __init__(self,
31
31
] # expected momentum if no warming up is performed
32
32
33
33
def _set_momentum (self , runner , momentum_groups ):
34
- for param_group , mom in zip (runner .optimizer .param_groups ,
35
- momentum_groups ):
36
- if 'momentum' in param_group .keys ():
37
- param_group ['momentum' ] = mom
38
- elif 'betas' in param_group .keys ():
39
- param_group ['betas' ] = (mom , param_group ['betas' ][1 ])
34
+ if isinstance (runner .optimizer , dict ):
35
+ for k , optim in runner .optimizer .items ():
36
+ for param_group , mom in zip (optim .param_groups ,
37
+ momentum_groups [k ]):
38
+ if 'momentum' in param_group .keys ():
39
+ param_group ['momentum' ] = mom
40
+ elif 'betas' in param_group .keys ():
41
+ param_group ['betas' ] = (mom , param_group ['betas' ][1 ])
42
+ else :
43
+ for param_group , mom in zip (runner .optimizer .param_groups ,
44
+ momentum_groups ):
45
+ if 'momentum' in param_group .keys ():
46
+ param_group ['momentum' ] = mom
47
+ elif 'betas' in param_group .keys ():
48
+ param_group ['betas' ] = (mom , param_group ['betas' ][1 ])
40
49
41
50
def get_momentum(self, runner, base_momentum):
    """Return the momentum derived from ``base_momentum``.

    This base implementation always raises ``NotImplementedError``;
    presumably subclasses are expected to override it.
    """
    raise NotImplementedError
43
52
44
53
def get_regular_momentum(self, runner):
    """Compute the momentum for every param group via ``get_momentum``.

    Args:
        runner: Object whose ``optimizer`` attribute is either a single
            optimizer or a dict of named optimizers. It is also passed
            through to ``self.get_momentum``.

    Returns:
        A list with one momentum per entry of ``self.base_momentum`` when
        ``runner.optimizer`` is a single optimizer; otherwise a dict keyed
        like ``runner.optimizer`` whose values are such lists, built from
        ``self.base_momentum[key]``.
    """

    def _regular(base_momentums):
        return [
            self.get_momentum(runner, _base_momentum)
            for _base_momentum in base_momentums
        ]

    if isinstance(runner.optimizer, dict):
        return {
            name: _regular(self.base_momentum[name])
            for name in runner.optimizer
        }
    return _regular(self.base_momentum)
49
68
50
69
def get_warmup_momentum(self, cur_iters):
    """Compute the momentum to use at warmup iteration ``cur_iters``.

    Every value in ``self.regular_momentum`` is divided by a factor chosen
    by ``self.warmup``:

    * ``'constant'``: ``warmup_ratio``
    * ``'linear'``: ``1 - (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)``
    * ``'exp'``: ``warmup_ratio ** (1 - cur_iters / warmup_iters)``

    Args:
        cur_iters: Current iteration index, used with ``self.warmup_iters``
            to compute the warmup progress.

    Returns:
        A list of momentum values when ``self.regular_momentum`` is a list,
        or a dict of such lists (same keys) when it is a dict.

    Raises:
        ValueError: If ``self.warmup`` is not one of ``'constant'``,
            ``'linear'`` or ``'exp'``.
    """

    def _get_warmup_momentum(cur_iters, regular_momentum):
        if self.warmup == 'constant':
            divisor = self.warmup_ratio
        elif self.warmup == 'linear':
            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
            divisor = 1 - k
        elif self.warmup == 'exp':
            divisor = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
        else:
            raise ValueError(
                f'unsupported warmup type: {self.warmup!r}')
        # Use the list handed in by the caller, not self.regular_momentum:
        # in the multi-optimizer case that attribute is a dict and each
        # call must only see its own per-optimizer list.
        return [_momentum / divisor for _momentum in regular_momentum]

    if isinstance(self.regular_momentum, dict):
        return {
            key: _get_warmup_momentum(cur_iters, regular_momentum)
            for key, regular_momentum in self.regular_momentum.items()
        }
    return _get_warmup_momentum(cur_iters, self.regular_momentum)
65
98
66
99
def before_run(self, runner):
    """Record the initial momentum of every param group as the base value.

    Each param group gets an ``'initial_momentum'`` entry if it does not
    already have one, taken from its ``'momentum'`` entry when present and
    from ``betas[0]`` otherwise. ``self.base_momentum`` is then set to the
    list of these values, or to a dict of such lists keyed like
    ``runner.optimizer`` when that is a dict of optimizers.

    Args:
        runner: Object whose ``optimizer`` attribute is either a single
            optimizer exposing ``param_groups`` or a dict mapping names to
            such optimizers.
    """

    # NOTE: when resuming from a checkpoint,
    # if 'initial_momentum' is not saved,
    # it will be set according to the optimizer params
    def _collect_base_momentum(optim):
        for group in optim.param_groups:
            if 'momentum' in group:
                group.setdefault('initial_momentum', group['momentum'])
            else:
                group.setdefault('initial_momentum', group['betas'][0])
        return [group['initial_momentum'] for group in optim.param_groups]

    if isinstance(runner.optimizer, dict):
        self.base_momentum = {
            name: _collect_base_momentum(optim)
            for name, optim in runner.optimizer.items()
        }
    else:
        self.base_momentum = _collect_base_momentum(runner.optimizer)
79
125
80
126
def before_train_epoch (self , runner ):
81
127
if not self .by_epoch :
@@ -383,9 +429,11 @@ def get_regular_momentum(self, runner):
383
429
if isinstance (runner .optimizer , dict ):
384
430
momentum_groups = {}
385
431
for k , optim in runner .optimizer .items ():
386
- for param_group in optim .param_groups :
387
- momentum_groups [k ].append (
388
- self .get_momentum (runner , param_group ))
432
+ _momentum_group = [
433
+ self .get_momentum (runner , param_group )
434
+ for param_group in optim .param_groups
435
+ ]
436
+ momentum_groups .update ({k : _momentum_group })
389
437
return momentum_groups
390
438
else :
391
439
momentum_groups = []
0 commit comments