nanogpt.patch
diff --git a/config/train_shakespeare_char.py b/config/train_shakespeare_char.py
index 41c81df..df9b757 100644
--- a/config/train_shakespeare_char.py
+++ b/config/train_shakespeare_char.py
@@ -3,8 +3,8 @@
out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
-eval_iters = 200
-log_interval = 10 # don't print too too often
+eval_iters = 2
+log_interval = 1 # don't print too too often
# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False
@@ -22,7 +22,7 @@ block_size = 256 # context of up to 256 previous characters
n_layer = 6
n_head = 6
n_embd = 384
-dropout = 0.2
+dropout = 0.0
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
diff --git a/model.py b/model.py
index c698f8b..014630a 100644
--- a/model.py
+++ b/model.py
@@ -184,7 +184,8 @@ class GPT(nn.Module):
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+ flat_logits = logits.view(-1, logits.size(-1))
+ loss = F.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1, reduction='sum') / flat_logits.size(0)
else:
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
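
The only model.py change rewrites the loss so the normalizer is the fixed flattened batch size rather than the count of non-ignored targets. With the character-level data no target ever equals ignore_index, so the value is unchanged; presumably the fixed denominator is what lets the loss combine cleanly once the batch is sharded across workers. A small plain-PyTorch check of that equivalence (shapes are illustrative, not taken from the patch):

```python
import torch
import torch.nn.functional as F

B, T, V = 4, 8, 65                      # illustrative sizes
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))   # no -1 entries, so nothing is ignored

flat = logits.view(-1, V)
mean_loss = F.cross_entropy(flat, targets.view(-1), ignore_index=-1)  # reduction='mean'
sum_loss = F.cross_entropy(flat, targets.view(-1), ignore_index=-1, reduction='sum') / flat.size(0)
assert torch.allclose(mean_loss, sum_loss)   # identical when no target is ignored
```
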
diff --git a/train.py b/train.py
index a482ab7..b4463dc 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,7 @@ import os
import time
import math
import pickle
-from contextlib import nullcontext
+from contextlib import nullcontext, contextmanager
import numpy as np
import torch
@@ -28,6 +28,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from model import GPTConfig, GPT
+from single_controller import DTensor, active_sharding, Manager, LocalWorker, to_local, WorkerMesh, Sharding
+
+manager = Manager()
+workers = WorkerMesh([manager.create_worker(local=True) for i in range(2)])
+batch_sharding = Sharding(workers, 0)
+replicated_sharding = Sharding(workers, 'r')
# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
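
The new imports set up the single-controller mesh before any config is parsed: two local workers, a batch sharding that splits tensors along dimension 0, and a replicated sharding that keeps a full copy on every worker. A rough plain-torch illustration of what those two placements mean for a (batch_size, block_size) batch, assuming even splitting; the DTensor/Sharding calls themselves belong to the patch's single_controller package:

```python
import torch

batch_size, block_size, n_workers = 12, 256, 2
x = torch.randint(0, 65, (batch_size, block_size))

# Sharding(workers, 0): each worker holds one slice of dim 0 (the batch)
shards = torch.chunk(x, n_workers, dim=0)
assert shards[0].shape == (batch_size // n_workers, block_size)

# Sharding(workers, 'r'): every worker holds the same full tensor
replicas = [x.clone() for _ in range(n_workers)]
assert all(torch.equal(r, x) for r in replicas)
```
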
@@ -35,7 +41,7 @@ from model import GPTConfig, GPT
out_dir = 'out'
eval_interval = 2000
log_interval = 1
-eval_iters = 200
+eval_iters = 2
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = True # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
@@ -71,7 +77,7 @@ backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
-compile = True # use PyTorch 2.0 to compile the model to be faster
+compile = False # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open('configurator.py').read()) # overrides from command line or config file
@@ -117,10 +123,13 @@ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mod
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
def get_batch(split):
data = train_data if split == 'train' else val_data
- ix = torch.randint(len(data) - block_size, (batch_size,))
- x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
- y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
- if device_type == 'cuda':
+ with active_sharding(None):
+ ix = torch.randint(len(data) - block_size, (batch_size,))
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
+ x = DTensor.to_remote(x, sharding=batch_sharding)
+ y = DTensor.to_remote(y, sharding=batch_sharding)
+ if False and device_type == 'cuda':
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
else:
@@ -187,7 +196,17 @@ elif init_from.startswith('gpt2'):
if block_size < model.config.block_size:
model.crop_block_size(block_size)
model_args['block_size'] = block_size # so that the checkpoint will have the right value
-model.to(device)
+
+def dtensorify(module):
+ for name, param in module.named_parameters(recurse=False):
+ setattr(module, name, torch.nn.Parameter(DTensor.to_remote(param, sharding=replicated_sharding).to(device)))
+ for name, param in module.named_buffers(recurse=False):
+ setattr(module, name, DTensor.to_remote(param, sharding=replicated_sharding).to(device))
+
+model.apply(dtensorify)
+# calling 'to' on model causes the fake tensors to get moved to cuda,
+# but the underlying real tensors stay on cpu for some reason.
+#model.to(device)
# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
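
dtensorify replaces every leaf parameter and buffer with a replicated DTensor living on the workers, and model.apply() drives the recursion so each tensor is visited exactly once (recurse=False keeps child modules from being handled twice). The commented-out model.to(device) matches the comment above it: calling .to() would move the controller-side fake tensors to CUDA while the real tensors on the workers stay on CPU. A plain-PyTorch sketch of the same traversal pattern, with an ordinary clone standing in for DTensor.to_remote:

```python
import torch
import torch.nn as nn

def retag(module: nn.Module) -> None:
    # Visit only the tensors owned directly by this submodule;
    # module.apply() supplies the recursion over children.
    for name, param in module.named_parameters(recurse=False):
        setattr(module, name, nn.Parameter(param.detach().clone()))
    for name, buf in module.named_buffers(recurse=False):
        setattr(module, name, buf.detach().clone())

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
model.apply(retag)   # every Parameter/buffer swapped exactly once
```
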
@@ -219,7 +238,7 @@ def estimate_loss():
X, Y = get_batch(split)
with ctx:
logits, loss = model(X, Y)
- losses[k] = loss.item()
+ losses[k] = loss.to_sharding_('r')
out[split] = losses.mean()
model.train()
return out
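
estimate_loss no longer calls loss.item(), which would force a synchronous copy back to the controller on every eval iteration; each per-iteration loss is instead reduced to the replicated sharding and kept remote, and the training loop below fetches the aggregate with to_local(losses).then(report). A minimal sketch of that callback style, under the assumption that to_local returns a future-like object exposing .then (the name comes from the patch; the exact semantics are assumed):

```python
# Hypothetical stand-in for the handle returned by to_local(); the real
# single_controller object is only assumed to behave like this.
class FakeFuture:
    def __init__(self):
        self._callbacks = []
        self._value = None
        self._done = False

    def then(self, callback):
        if self._done:
            callback(self._value)
        else:
            self._callbacks.append(callback)
        return self

    def set_result(self, value):      # called when the copy-back finishes
        self._value, self._done = value, True
        for callback in self._callbacks:
            callback(value)

fut = FakeFuture()
fut.then(lambda losses: print(f"val loss {losses['val']:.4f}"))
fut.set_result({'train': 1.23, 'val': 1.45})   # prints once the result arrives
```
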
@@ -250,65 +269,76 @@ local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:
+ with active_sharding(replicated_sharding):
+ # determine and set the learning rate for this iteration
+ lr = get_lr(iter_num) if decay_lr else learning_rate
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
- # determine and set the learning rate for this iteration
- lr = get_lr(iter_num) if decay_lr else learning_rate
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr
-
- # evaluate the loss on train/val sets and write checkpoints
- if iter_num % eval_interval == 0 and master_process:
- losses = estimate_loss()
- print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
- if wandb_log:
- wandb.log({
- "iter": iter_num,
- "train/loss": losses['train'],
- "val/loss": losses['val'],
- "lr": lr,
- "mfu": running_mfu*100, # convert to percentage
- })
- if losses['val'] < best_val_loss or always_save_checkpoint:
- best_val_loss = losses['val']
- if iter_num > 0:
- checkpoint = {
- 'model': raw_model.state_dict(),
- 'optimizer': optimizer.state_dict(),
- 'model_args': model_args,
- 'iter_num': iter_num,
- 'best_val_loss': best_val_loss,
- 'config': config,
- }
- print(f"saving checkpoint to {out_dir}")
- torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
- if iter_num == 0 and eval_only:
- break
+ # evaluate the loss on train/val sets and write checkpoints
+ if iter_num % eval_interval == 0 and master_process:
+ def report(losses):
+ global best_val_loss
+ print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
+ if wandb_log:
+ wandb.log({
+ "iter": iter_num,
+ "train/loss": losses['train'],
+ "val/loss": losses['val'],
+ "lr": lr,
+ "mfu": running_mfu*100, # convert to percentage
+ })
+ if losses['val'] < best_val_loss or always_save_checkpoint:
+ best_val_loss = losses['val']
+ # TODO: thread safe ability to issue commands
+ if iter_num > 0 and False:
+ checkpoint = {
+ 'model': raw_model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'model_args': model_args,
+ 'iter_num': iter_num,
+ 'best_val_loss': best_val_loss,
+ 'config': config,
+ }
+ print(f"saving checkpoint to {out_dir}")
+ torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
+ losses = estimate_loss()
+ to_local(losses).then(report)
- # forward backward update, with optional gradient accumulation to simulate larger batch size
- # and using the GradScaler if data type is float16
- for micro_step in range(gradient_accumulation_steps):
- if ddp:
- # in DDP training we only need to sync gradients at the last micro step.
- # the official way to do this is with model.no_sync() context manager, but
- # I really dislike that this bloats the code and forces us to repeat code
- # looking at the source of that context manager, it just toggles this variable
- model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
- with ctx:
- logits, loss = model(X, Y)
- loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
- # immediately async prefetch next batch while model is doing the forward pass on the GPU
- X, Y = get_batch('train')
- # backward pass, with gradient scaling if training in fp16
- scaler.scale(loss).backward()
- # clip the gradient
- if grad_clip != 0.0:
- scaler.unscale_(optimizer)
- torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
- # step the optimizer and scaler if training in fp16
- scaler.step(optimizer)
- scaler.update()
- # flush the gradients as soon as we can, no need for this memory anymore
- optimizer.zero_grad(set_to_none=True)
+ if iter_num == 0 and eval_only:
+ break
+
+ # forward backward update, with optional gradient accumulation to simulate larger batch size
+ # and using the GradScaler if data type is float16
+ for micro_step in range(gradient_accumulation_steps):
+ if ddp:
+ # in DDP training we only need to sync gradients at the last micro step.
+ # the official way to do this is with model.no_sync() context manager, but
+ # I really dislike that this bloats the code and forces us to repeat code
+ # looking at the source of that context manager, it just toggles this variable
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
+ with ctx:
+ logits, loss = model(X, Y)
+ loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
+ # immediately async prefetch next batch while model is doing the forward pass on the GPU
+ X, Y = get_batch('train')
+ # backward pass, with gradient scaling if training in fp16
+ scaler.scale(loss).backward()
+
+ for p in model.parameters():
+ p.grad.to_sharding_('r')
+
+ # clip the gradient
+ if grad_clip != 0.0:
+ scaler.unscale_(optimizer)
+ for p in model.parameters():
+ assert p.sharding.sharding == ['r']
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+ # step the optimizer and scaler if training in fp16
+ scaler.step(optimizer)
+ scaler.update()
+ # flush the gradients as soon as we can, no need for this memory anymore
+ optimizer.zero_grad(set_to_none=True)
# timing and logging
t1 = time.time()
@@ -317,11 +347,15 @@ while True:
if iter_num % log_interval == 0 and master_process:
# get loss as float. note: this is a CPU-GPU sync point
# scale up to undo the division above, approximating the true total loss (exact would have been a sum)
- lossf = loss.item() * gradient_accumulation_steps
if local_iter_num >= 5: # let the training loop settle a bit
mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
- print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
+
+ def report(loss):
+ lossf = loss.item() * gradient_accumulation_steps
+ print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
+ loss.to_sharding_('r').to_local().then(report)
+
iter_num += 1
local_iter_num += 1
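
Taken together, the train.py changes implement a small data-parallel loop on top of the single-controller API: inputs enter the mesh batch-sharded, parameters (and therefore optimizer state) stay replicated, gradients are reduced back to the replicated sharding after backward, and the only values that return to the controller, the logged losses, come back asynchronously. A condensed sketch of that flow; the single_controller calls are the ones used in the patch, while the toy model, shapes, and exact call semantics are assumptions:

```python
import torch
import torch.nn.functional as F
from single_controller import DTensor, Manager, Sharding, WorkerMesh, active_sharding

manager = Manager()
workers = WorkerMesh([manager.create_worker(local=True) for _ in range(2)])
batch_sharding = Sharding(workers, 0)         # split along dim 0 (the batch)
replicated_sharding = Sharding(workers, 'r')  # full copy on every worker

# inputs: each worker sees its slice of the batch (as in get_batch)
x = DTensor.to_remote(torch.randn(12, 384), sharding=batch_sharding)
y = DTensor.to_remote(torch.randint(0, 65, (12,)), sharding=batch_sharding)

# parameters: replicated, as dtensorify() does for every GPT parameter
w = torch.nn.Parameter(DTensor.to_remote(torch.randn(384, 65), sharding=replicated_sharding))
opt = torch.optim.SGD([w], lr=1e-3)

with active_sharding(replicated_sharding):
    loss = F.cross_entropy(x @ w, y)          # forward/backward run on the workers
loss.backward()
w.grad.to_sharding_('r')                      # reduce the data-parallel gradient
opt.step()
opt.zero_grad(set_to_none=True)

# only logging pulls data back to the controller, and it does so asynchronously
loss.to_sharding_('r').to_local().then(lambda l: print(f"loss {l.item():.4f}"))
```
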