optimizer.py
from __future__ import absolute_import
import tensorflow as tf
import tools
slim = tf.contrib.slim
class PolyOptimizer(object):
    def __init__(self, training_params):
        self.base_lr = training_params["base_lr"]
        self.warmup_steps = training_params["warmup_iter"]
        self.warmup_learning_rate = training_params["warmup_start_lr"]
        self.power = 2.0
        self.momentum = 0.9

    def optimize(self, loss, training, total_steps):
        """
        Momentum optimizer using a polynomial decay and a warmup phase to match this
        prototxt: https://github.com/amirgholami/SqueezeNext/blob/master/1.0-SqNxt-23/solver.prototxt
        :param loss:
            Scalar loss tensor
        :param training:
            Whether or not the model is training; used to prevent updating the
            moving mean of batch norm during eval
        :param total_steps:
            Total number of training steps, used as the decay horizon of the
            polynomial decay
        :return:
            Train op created with slim.learning.create_train_op
        """
        with tf.name_scope("PolyOptimizer"):
            global_step = tools.get_or_create_global_step()
            # Learning rate decays polynomially (power 2.0) from base_lr over
            # total_steps.
            learning_rate_schedule = tf.train.polynomial_decay(
                learning_rate=self.base_lr,
                global_step=global_step,
                decay_steps=total_steps,
                power=self.power
            )
            # Blend in the warmup phase: start from warmup_start_lr and hand
            # over to the polynomial decay after warmup_iter steps.
            learning_rate_schedule = tools.warmup_phase(learning_rate_schedule,
                                                        self.base_lr,
                                                        self.warmup_steps,
                                                        self.warmup_learning_rate)
            tf.summary.scalar("learning_rate", learning_rate_schedule)
            optimizer = tf.train.MomentumOptimizer(learning_rate_schedule, self.momentum)
            # Skip the batch norm moving-average update ops when not training.
            return slim.learning.create_train_op(loss,
                                                 optimizer,
                                                 global_step=global_step,
                                                 aggregation_method=tf.AggregationMethod.ADD_N,
                                                 update_ops=None if training else [])
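
For reference, a hypothetical sketch of what a warmup wrapper in the spirit of tools.warmup_phase could look like. The actual implementation lives in tools.py and may differ; in particular, the explicit global_step argument and the linear ramp here are assumptions of this sketch, not the repository's signature or behavior.

# Hypothetical warmup sketch, not the repository's tools.warmup_phase: linearly
# ramp the learning rate from warmup_start_lr to base_lr over warmup_steps, then
# fall back to the wrapped decay schedule.
def linear_warmup_sketch(decayed_lr, base_lr, warmup_steps, warmup_start_lr, global_step):
    step = tf.cast(global_step, tf.float32)
    warmup_total = tf.cast(warmup_steps, tf.float32)
    warmup_lr = warmup_start_lr + (base_lr - warmup_start_lr) * (step / warmup_total)
    return tf.where(step < warmup_total, warmup_lr, decayed_lr)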
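
A minimal usage sketch of PolyOptimizer under TensorFlow 1.x. Only the training_params keys come from __init__ above; the parameter values, the loss tensor, the step count, and the session loop are illustrative assumptions.

# Illustrative usage; the values below are assumptions, not taken from the repository.
training_params = {
    "base_lr": 0.04,           # assumed value
    "warmup_iter": 1000,       # assumed value
    "warmup_start_lr": 0.001,  # assumed value
}

loss = tf.losses.get_total_loss()  # scalar loss assembled elsewhere in the graph
optimizer = PolyOptimizer(training_params)
train_op = optimizer.optimize(loss, training=True, total_steps=100000)

with tf.train.MonitoredTrainingSession(
        hooks=[tf.train.StopAtStepHook(last_step=100000)]) as sess:
    while not sess.should_stop():
        sess.run(train_op)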