-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathoptimizer.py
170 lines (145 loc) · 5.55 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import math
import torch
from torch.optim import AdamW
class MemoryEfficientAdamW(AdamW):
    """
    AdamW optimizer that keeps optimizer states on the CPU when enabled.

    Parameters and gradients stay on their own device (typically a GPU). The
    first/second moment buffers (and the AMSGrad running maximum) are stored
    on the CPU in float32, optionally in pinned memory. Each step copies a
    parameter's states to the parameter's device, updates them there, and
    copies them back. The update rule (decoupled weight decay, bias
    correction and placement of ``eps``) follows ``torch.optim.AdamW``.

    When disabled, behaves exactly like standard AdamW.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        betas: Coefficients for the running averages of the gradient and of
            its square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Decoupled weight-decay coefficient.
        amsgrad: Whether to use the AMSGrad variant.
        pin_memory: Allocate the CPU state buffers in pinned memory. Only
            honoured when CUDA is available; otherwise pinning would fail.
        enabled: If False, ``step`` defers to ``torch.optim.AdamW``.

    NOTE(review): states created while enabled live on the CPU. Turning
    ``enabled`` off afterwards hands them to AdamW as-is, which is likely to
    fail for GPU parameters. Confirm before toggling mid-training.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        pin_memory=True,
        enabled=True,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        self.pin_memory = pin_memory
        self.enabled = enabled

    def _cpu_state_buffer(self, state, key, param):
        """Return ``state[key]`` as a CPU float32 buffer shaped like ``param``.

        A missing buffer is zero-initialised. A buffer found on another device
        or in another dtype is copied back to CPU float32 and stored in its
        place. ``Optimizer.load_state_dict`` casts state to the parameter's
        device and dtype, which would otherwise silently undo the offloading.
        """
        # Pinned allocations require CUDA; requesting them on a CPU-only host raises.
        pin = self.pin_memory and torch.cuda.is_available()
        buf = state.get(key)
        if buf is None:
            buf = torch.zeros_like(
                param, device="cpu", dtype=torch.float32, pin_memory=pin
            )
            state[key] = buf
        elif buf.device.type != "cpu" or buf.dtype != torch.float32:
            cpu_buf = torch.empty_like(
                buf, device="cpu", dtype=torch.float32, pin_memory=pin
            )
            cpu_buf.copy_(buf)
            state[key] = buf = cpu_buf
        return buf

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None if no closure was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        if not self.enabled:
            # Use the parent AdamW implementation when disabled.
            return super().step(closure)

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]
            amsgrad = group["amsgrad"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("AdamW does not support sparse gradients")
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                state.setdefault("step", 0)
                exp_avgs.append(self._cpu_state_buffer(state, "exp_avg", p))
                exp_avg_sqs.append(self._cpu_state_buffer(state, "exp_avg_sq", p))
                if amsgrad:
                    max_exp_avg_sqs.append(
                        self._cpu_state_buffer(state, "max_exp_avg_sq", p)
                    )

                # "step" is an int for fresh state but becomes a scalar tensor
                # after load_state_dict(); += works for both, float() normalises.
                state["step"] += 1
                state_steps.append(float(state["step"]))

            self._memory_efficient_update(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group.get("maximize", False),
            )
        return loss

    def _memory_efficient_update(
        self,
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad,
        beta1,
        beta2,
        lr,
        weight_decay,
        eps,
        maximize=False,
    ):
        """Apply the AdamW update to ``params`` using CPU-resident states.

        For each parameter, its states are copied to the parameter's device
        and updated there, then written back into the CPU buffers. The update
        matches ``torch.optim.AdamW``:
        ``denom = sqrt(v) / sqrt(1 - beta2**t) + eps`` and
        ``step_size = lr / (1 - beta1**t)``.

        Args:
            params: Parameters that have gradients, updated in place.
            grads: Gradients aligned with ``params``.
            exp_avgs, exp_avg_sqs: CPU buffers for the first and second
                moments, aligned with ``params``. Updated in place.
            max_exp_avg_sqs: CPU AMSGrad buffers (only read when ``amsgrad``).
            state_steps: Step count for each parameter, already incremented.
            amsgrad, beta1, beta2, lr, weight_decay, eps, maximize:
                Hyper-parameters of the parameter group.
        """
        cuda_devices = set()
        for i, param in enumerate(params):
            grad = -grads[i] if maximize else grads[i]
            device = param.device
            # Only CUDA transfers are issued asynchronously. They are fenced
            # below so the CPU buffers are valid once this method returns.
            non_blocking = device.type == "cuda"
            if non_blocking:
                cuda_devices.add(device)

            exp_avg = exp_avgs[i].to(device, non_blocking=non_blocking)
            exp_avg_sq = exp_avg_sqs[i].to(device, non_blocking=non_blocking)
            step = state_steps[i]

            # Decoupled weight decay (AdamW), applied directly to the parameter.
            if weight_decay != 0:
                param.mul_(1 - lr * weight_decay)

            # Decay the first and second moment running averages.
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            bias_correction1 = 1 - beta1**step
            bias_correction2_sqrt = math.sqrt(1 - beta2**step)
            step_size = lr / bias_correction1

            if amsgrad:
                max_exp_avg_sq = max_exp_avg_sqs[i].to(device, non_blocking=non_blocking)
                # Keep the maximum of all second-moment running averages so far.
                torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
                max_exp_avg_sqs[i].copy_(max_exp_avg_sq, non_blocking=non_blocking)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

            # Write the updated states back to their CPU buffers.
            exp_avgs[i].copy_(exp_avg, non_blocking=non_blocking)
            exp_avg_sqs[i].copy_(exp_avg_sq, non_blocking=non_blocking)

        # Wait for pending device-to-host copies before the CPU state is read
        # (e.g. by state_dict()).
        for device in cuda_devices:
            torch.cuda.current_stream(device).synchronize()