-
Notifications
You must be signed in to change notification settings - Fork 55
/
lr_policy.py
22 lines (18 loc) · 822 Bytes
/
lr_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
class StepLR(object):
    """Cyclic step-decay learning-rate schedule.

    Each call to :meth:`step` advances an internal iteration counter and sets
    every parameter group's ``'lr'`` to::

        base_lr * gamma ** (counter // step_size)

    The counter wraps back to 0 once it reaches ``max_iter``, so the learning
    rate returns to its base value and the decay starts over.

    Args:
        optimizer: Object exposing ``param_groups``, a sequence of dicts that
            each hold an ``'lr'`` entry. The values present at construction
            time are recorded as the base learning rates.
        step_size: Number of iterations between successive decays.
        max_iter: Length of one schedule cycle, in iterations.
        gamma: Multiplicative decay factor applied every ``step_size``
            iterations.

    Raises:
        ValueError: If ``step_size`` or ``max_iter`` is not positive.
    """

    def __init__(self, optimizer, step_size=1000, max_iter=10000, gamma=0.1):
        # Both values are used as divisors in step(); reject bad input here
        # instead of failing later with a ZeroDivisionError mid-training.
        if step_size <= 0:
            raise ValueError(
                "step_size must be positive, got {!r}".format(step_size))
        if max_iter <= 0:
            raise ValueError(
                "max_iter must be positive, got {!r}".format(max_iter))

        self.optimizer = optimizer
        self.max_iter = max_iter
        self.step_size = step_size
        self.gamma = gamma
        # -1 so that the first call to step() lands on iteration 0.
        self.last_iter = -1
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]

    def get_lr(self):
        """Return the current learning rate of the first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def step(self, last_iter=None):
        """Advance the schedule by one iteration and update the optimizer.

        Args:
            last_iter: Optional value to load into the internal counter
                before it is advanced; when omitted the counter continues
                from where the previous call left it.
        """
        if last_iter is not None:
            self.last_iter = last_iter

        # The modulo alone handles the wrap-around at the end of a cycle:
        # (max_iter - 1 + 1) % max_iter == 0.
        self.last_iter = (self.last_iter + 1) % self.max_iter

        # The decay factor is the same for every group, so compute it once.
        decay = self.gamma ** (self.last_iter // self.step_size)
        for idx, param_group in enumerate(self.optimizer.param_groups):
            param_group['lr'] = self.base_lrs[idx] * decay