train_utils.py
from typing import Any, Dict, List

from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger


class ProdigyLRMonitor(LearningRateMonitor):
    """LearningRateMonitor variant for the Prodigy optimizer.

    Prodigy keeps its adaptive step-size estimate in the ``d`` entry of
    each param group, so the effective learning rate is ``lr * d``
    rather than ``lr`` alone. This subclass logs that effective value.
    """

    def _get_optimizer_stats(
        self, optimizer: Optimizer, names: List[str]
    ) -> Dict[str, float]:
        stats = {}
        param_groups = optimizer.param_groups
        use_betas = "betas" in optimizer.defaults

        for pg, name in zip(param_groups, names):
            # Effective learning rate (see _extract_lr below).
            lr = self._extract_lr(pg, name)
            stats.update(lr)
            momentum = self._extract_momentum(
                param_group=pg,
                name=f"{name}-momentum",
                use_betas=use_betas,
            )
            stats.update(momentum)
            weight_decay = self._extract_weight_decay(pg, f"{name}-weight_decay")
            stats.update(weight_decay)

        return stats
    def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
        # Prodigy scales its base lr by the adaptive estimate "d"; fall
        # back to d = 1 for param groups that do not define it.
        lr = param_group["lr"]
        d = param_group.get("d", 1)
        self.lrs[name].append(lr * d)
        return {name: lr * d}
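

# --- Usage sketch (an illustration, not part of the original class) ---
# A minimal, hypothetical example of wiring the monitor into a Trainer.
# It assumes the Prodigy optimizer (the `prodigyopt` package), which is
# what stores the "d" step-size estimate read above; the project and
# metric names are placeholders.
def build_demo_trainer() -> pl.Trainer:
    # ProdigyLRMonitor drops in wherever LearningRateMonitor would go.
    return pl.Trainer(
        max_steps=100,
        logger=WandbLogger(project="prodigy-demo"),
        callbacks=[
            ProdigyLRMonitor(logging_interval="step"),
            ModelCheckpoint(monitor="val_loss"),
        ],
    )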