
Commit

.item()
kexinyu committed Apr 5, 2020
1 parent a0bf956 commit d38e6fe
Showing 2 changed files with 9 additions and 11 deletions.
4 changes: 2 additions & 2 deletions apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu
@@ -41,8 +41,8 @@ struct LAMBStage1Functor
     const float epsilon,
     adamMode_t mode,
     const float decay,
-    float global_grad_norm,
-    float max_global_grad_norm)
+    const float global_grad_norm,
+    const float max_global_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
16 changes: 7 additions & 9 deletions apex/contrib/optimizers/fused_lamb.py
@@ -83,7 +83,6 @@ def __init__(self, params, lr=1e-3, bias_correction=True,

         self.adam_w_mode = 1 if adam_w_mode else 0
         self.set_grad_none = set_grad_none
-        print("debugging LAMB")
 
     def zero_grad(self):
         if self.set_grad_none:
@@ -116,23 +115,22 @@ def step(self, closure=None):
                     g_all_16.append(p.grad.data)
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+:q!

         g_norm_32, g_norm_16 = 0.0, 0.0
         # compute grad norm for two lists
         if len(g_all_32) > 0:
-            g_norm_32, _ = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                self._dummy_overflow_buf,
-                                                [g_all_32], False)
+            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_32], False)[0].item()
         if len(g_all_16) > 0:
-            g_norm_16, _ = multi_tensor_applier(self.multi_tensor_l2norm,
-                                                self._dummy_overflow_buf,
-                                                [g_all_16], False)
+            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
+                                             self._dummy_overflow_buf,
+                                             [g_all_16], False)[0].item()
 
         # blend two grad norms to get global grad norm
         global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
         max_grad_norm = self.defaults['max_grad_norm']
-        print("====global_grad_norm:", global_grad_norm)
-        print("====max_grad_norm:", max_grad_norm)
 
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
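For context, here is a minimal PyTorch-only sketch of the gradient-norm blending that step() performs after this change. The helper name blend_global_grad_norm and the use of torch.norm in place of apex's fused multi_tensor_l2norm kernel are illustrative assumptions, not part of the commit. The point of the diff is that multi_tensor_applier returns the norm as the first element of a pair, as a 1-element CUDA tensor, so [0].item() is needed to turn it into a plain Python float, presumably so that a float rather than a tensor is forwarded to the kernel's const float global_grad_norm argument in the .cu file above.

import math
import torch

def blend_global_grad_norm(g_all_32, g_all_16):
    # Hypothetical stand-in for the fused multi_tensor_l2norm path:
    # compute one L2 norm per dtype bucket as a Python float (.item()),
    # then blend the two bucket norms into a single global L2 norm,
    # matching the math.sqrt blend in the updated step() above.
    g_norm_32, g_norm_16 = 0.0, 0.0
    if len(g_all_32) > 0:
        g_norm_32 = torch.norm(torch.stack([g.norm() for g in g_all_32])).item()
    if len(g_all_16) > 0:
        g_norm_16 = torch.norm(torch.stack([g.float().norm() for g in g_all_16])).item()
    return math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)

For example, blend_global_grad_norm([torch.randn(10), torch.randn(3, 3)], []) returns the same value as flattening and concatenating those gradients and taking a single L2 norm.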
