-
Notifications
You must be signed in to change notification settings - Fork 9
/
grad_norm.py
35 lines (27 loc) · 880 Bytes
/
grad_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from typing import Dict
class GradNormTracker:
    """Track a sliding-window average of per-parameter gradient L2 norms.

    For each registered name, keeps the ``win_size`` most recent gradient
    norms plus a running sum, so ``get()`` returns window means in O(1)
    per entry without re-summing the window.
    """

    def __init__(self, win_size: int = 10):
        """
        Args:
            win_size: maximum number of recent norms kept per parameter
                (the window over which ``get()`` averages).
        """
        self.clear()
        self.win_size = win_size

    def clear(self):
        """Drop all recorded norms and running sums."""
        # name -> list of recent norm values, oldest first
        self.norms: Dict[str, list] = {}
        # name -> running sum of the values in self.norms[name]
        self.sums: Dict[str, float] = {}

    def get(self):
        """Return {name: mean norm over the current window}, skipping empty entries."""
        return {k: self.sums[k] / len(v) for k, v in self.norms.items() if v}

    def add(self, name: str, param: torch.nn.Parameter):
        """Record the L2 norm of ``param.grad`` under ``name``.

        No-op when the parameter has no gradient (e.g. frozen or not yet
        backpropagated through).
        """
        if param.grad is None:
            return
        norm = param.grad.norm().item()
        window = self.norms.get(name)
        if window is None:
            # First observation for this name: start a fresh window.
            self.norms[name] = [norm]
            self.sums[name] = norm
            return
        # Evict the oldest entry BEFORE appending so the window never holds
        # more than win_size values. (The original `>` check let the window
        # grow to win_size + 1, skewing the average reported by get().)
        if len(window) >= self.win_size:
            self.sums[name] -= window.pop(0)
        window.append(norm)
        self.sums[name] += norm

    def add_dict(self, data: Dict[str, torch.Tensor]):
        """Record gradient norms for every (name, param) pair in ``data``."""
        for name, param in data.items():
            self.add(name, param)