sgld.py
"""
Entropy-SGD TensorFlow implementation
Original paper: arXiv 1611.01838
Justin Tan 2017
"""
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops

class local_entropy_sgld(optimizer.Optimizer):

    def __init__(self, eta_prime, epsilon, gamma, alpha, momentum, L,
                 sgld_global_step, use_locking=False, name='le_sgld'):
        # Runs the inner-loop Langevin dynamics of Entropy-SGD
        super(local_entropy_sgld, self).__init__(use_locking, name)
        self._lr_prime = eta_prime
        self._epsilon = epsilon
        self._gamma = gamma
        self._momentum = momentum
        self._alpha = alpha
        self._L = L
        self.sgld_global_step = sgld_global_step

        # Parameter tensors
        self._lr_prime_t = None
        self._epsilon_t = None
        self._gamma_t = None
        self._momentum_t = None
        self._alpha_t = None
        self._L_t = None
        self._sgld_gs_t = None

    def _prepare(self):
        self._lr_prime_t = ops.convert_to_tensor(self._lr_prime,
                                                 name="learning_rate_prime")
        self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
        self._gamma_t = ops.convert_to_tensor(self._gamma, name="gamma")
        self._momentum_t = ops.convert_to_tensor(self._momentum, name="momentum")
        self._alpha_t = ops.convert_to_tensor(self._alpha, name="alpha")
        self._L_t = ops.convert_to_tensor(self._L, name="L")
        self._sgld_gs_t = ops.convert_to_tensor(self.sgld_global_step,
                                                name="sgld_global_step")

    def _create_slots(self, var_list):
        # Manage variables that accumulate updates.
        # Creates slots for the candidate weights x', the running expectation
        # mu = <x'>, the anchor (current) weights wc and the momentum buffer mv.
        for v in var_list:
            self._zeros_slot(v, "wc", self._name)
            self._zeros_slot(v, "xp", self._name)
            self._zeros_slot(v, "mu", self._name)
            self._zeros_slot(v, "mv", self._name)

    def _apply_dense(self, grad, var):
        # Updates dummy weights during SGLD;
        # reassign to original weights upon completion of the inner loop.
        lr_prime_t = math_ops.cast(self._lr_prime_t, var.dtype.base_dtype)
        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
        gamma_t = math_ops.cast(self._gamma_t, var.dtype.base_dtype)
        momentum_t = math_ops.cast(self._momentum_t, var.dtype.base_dtype)
        alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)

        wc = self.get_slot(var, 'wc')
        xp = self.get_slot(var, 'xp')
        mu = self.get_slot(var, 'mu')
        mv = self.get_slot(var, 'mv')

        # Refresh the anchor weights wc at the start of each inner loop,
        # i.e. whenever sgld_global_step is a multiple of L
        wc_t = tf.cond(tf.logical_not(tf.cast(tf.mod(self.sgld_global_step, self._L_t), tf.bool)),
                       lambda: wc.assign(var),
                       lambda: wc)

        # Gaussian noise for the Langevin dynamics
        eta = tf.random_normal(shape=var.get_shape())
        eta_t = math_ops.cast(eta, var.dtype.base_dtype)

        # update = -lr_prime_t*(grad - gamma_t*(wc - var)) + tf.sqrt(lr_prime_t)*epsilon_t*eta_t
        mv_t = mv.assign(momentum_t*mv + grad)

        # Nesterov's momentum enabled by default
        if self._momentum > 0:
            xp_t = xp.assign(var - lr_prime_t*(grad - gamma_t*(wc - var))
                             + tf.sqrt(lr_prime_t)*epsilon_t*eta_t - lr_prime_t*momentum_t*mv_t)
            var_update = state_ops.assign_sub(var,
                lr_prime_t*(grad - gamma_t*(wc - var)) - tf.sqrt(lr_prime_t)*epsilon_t*eta_t
                + lr_prime_t*momentum_t*mv_t)
        else:
            xp_t = xp.assign(var - lr_prime_t*(grad - gamma_t*(wc - var))
                             + tf.sqrt(lr_prime_t)*epsilon_t*eta_t)
            var_update = state_ops.assign_sub(var,
                lr_prime_t*(grad - gamma_t*(wc - var)) - tf.sqrt(lr_prime_t)*epsilon_t*eta_t)

        # Running average mu = (1 - alpha)*mu + alpha*x'
        mu_t = mu.assign((1.0 - alpha_t)*mu + alpha_t*xp_t)

        return control_flow_ops.group(*[var_update, mv_t, wc_t, xp_t, mu_t])

    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Optimizer does not yet support sparse gradient updates.")
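

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file): drives the inner-loop
    # SGLD optimizer for L Langevin steps on a toy least-squares problem, so the
    # slot bookkeeping above can be seen end to end. The hyperparameter values
    # and the toy loss below are illustrative assumptions, not the author's
    # experimental setup; in Entropy-SGD the mu slot accumulated here would be
    # consumed by an outer optimizer step.
    L = 20
    sgld_step = tf.Variable(0, trainable=False, name='sgld_global_step')
    w = tf.Variable(tf.random_normal([10, 1]), name='w')
    x = tf.random_normal([32, 10])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

    inner_opt = local_entropy_sgld(eta_prime=0.1, epsilon=1e-4, gamma=0.03,
                                   alpha=0.75, momentum=0.9, L=L,
                                   sgld_global_step=sgld_step)
    inner_train_op = inner_opt.minimize(loss, global_step=sgld_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # One inner loop of L SGLD iterations
        for _ in range(L):
            sess.run(inner_train_op)
        # mu = <x'> accumulated by the inner loop
        print(sess.run(inner_opt.get_slot(w, 'mu')))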