-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathpretrain_cpm_bee.py
389 lines (336 loc) · 15.2 KB
/
pretrain_cpm_bee.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# coding=utf-8
# Copyright 2022 The OpenBMB team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import Any, Dict, List, Union
import torch
import bmtrain as bmt
import os
from cpm_live.arguments import get_args
from cpm_live.models import CPMBee, CPMBeeConfig
from cpm_live.tokenizers import CPMBeeTokenizer
from cpm_live.utils import allgather_objects, LogManager
from cpm_live.training_tasks.bee import MixedDataset
def get_tokenizer(args):
    """Return a freshly constructed ``CPMBeeTokenizer``.

    ``args`` is accepted so this factory has the same shape as the other
    ``get_*`` helpers; it is not read.
    """
    return CPMBeeTokenizer()
def get_model(args):
    """Build a CPM-Bee model described by ``args.model_config``.

    If ``args.load`` names a checkpoint, its weights are loaded with
    ``bmt.load``; otherwise parameters are initialized with
    ``bmt.init_parameters``.
    """
    model = CPMBee(CPMBeeConfig.from_json_file(args.model_config))
    if args.load is None:
        bmt.init_parameters(model)
    else:
        bmt.load(model, args.load)
    return model
def get_optimizer(args, model):
    """Create a ``bmt.optim.AdamOffloadOptimizer`` and restore its state when resuming.

    Optimizer state is restored only when ``args.load`` is set AND
    ``args.save`` is configured. The state files are read from the save
    directory, not from ``args.load``. Each rank loads its own
    ``<save>/<save_name>.rank-<rank>.opt``. Whether to restore at all is decided
    by the existence of rank 0's file, presumably so that every rank makes
    the same decision (confirm).

    Args:
        args: parsed arguments; reads ``weight_decay``, ``load``, ``save`` and
            ``save_name``.
        model: module whose parameters are optimized.

    Returns:
        The optimizer, with state loaded if a saved state was found.
    """
    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(), weight_decay=args.weight_decay
    )
    if args.load is None or args.save is None:
        # Not resuming, or no save directory that optimizer states could have
        # been written to (os.path.join(None, ...) would raise TypeError).
        return optimizer

    def _opt_state_path(rank):
        # Same naming scheme the training loop uses when checkpointing.
        return os.path.join(args.save, args.save_name + (".rank-%d.opt" % rank))

    if os.path.exists(_opt_state_path(0)):
        # optimizer state exists
        states = torch.load(_opt_state_path(bmt.rank()))
        optimizer.load_state_dict(states)
    return optimizer
def get_learning_rate_scheduler(args, optimizer):
    """Create a ``bmt.lr_scheduler.Noam`` scheduler for ``optimizer``.

    When ``args.lr_decay_iters`` is None it defaults to ``args.train_iters``.
    Note that this default is written back into ``args``. The scheduler starts
    counting from ``args.start_step`` so resumed runs continue the schedule.
    """
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    noam_kwargs = dict(
        start_lr=args.lr,
        warmup_iter=args.warmup_iters,
        end_iter=args.lr_decay_iters,
        num_iter=args.start_step,
    )
    return bmt.lr_scheduler.Noam(optimizer, **noam_kwargs)
def setup_model_and_optimizer(args):
    """Build every training component and wire them together.

    Returns:
        ``(tokenizer, model, optimizer, lr_scheduler, optim_manager)``. The
        ``OptimManager`` owns loss scaling and has the optimizer/scheduler
        pair registered with it.
    """
    model = get_model(args)
    tokenizer = get_tokenizer(args)
    bmt.synchronize()

    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()

    # Dynamic loss-scaling settings passed to bmtrain's OptimManager.
    scaling = dict(loss_scale=args.loss_scale, loss_scale_factor=2, loss_scale_steps=512)
    optim_manager = bmt.optim.OptimManager(**scaling)
    optim_manager.add_optimizer(optimizer, lr_scheduler)
    return tokenizer, model, optimizer, lr_scheduler, optim_manager
def initialize():
    """Parse arguments and bring up bmtrain's distributed environment.

    Side effects:
        * Adds 2333 to the ``MASTER_PORT`` environment variable, which must
          already be set. The reason for the offset is not visible here; it
          presumably avoids a port the launcher already uses (confirm).
        * Calls ``bmt.init_distributed`` with ``args.seed``.
        * Creates ``args.save`` if it is set.

    Returns:
        The parsed pretraining arguments.
    """
    shifted_port = int(os.environ["MASTER_PORT"]) + 2333
    os.environ["MASTER_PORT"] = str(shifted_port)
    args = get_args(pretrain=True)
    bmt.init_distributed(seed=args.seed)
    if args.save is not None:
        os.makedirs(args.save, exist_ok=True)
    return args
def see_memory(detail=False):
    """Report CUDA memory usage and reset the peak-memory counter.

    Args:
        detail: if True, return ``torch.cuda.memory_summary()`` text.

    Returns:
        The summary string when ``detail`` is True. Otherwise a tuple
        ``(allocated_gib, peak_allocated_gib)``, each rounded to 2 decimals.
        The peak is measured since the previous reset.

    In both cases ``torch.cuda.reset_peak_memory_stats()`` is called
    afterwards, so the next call reports a fresh peak.
    """
    bytes_per_gib = 1024 * 1024 * 1024
    if detail:
        report = torch.cuda.memory_summary()
    else:
        raw = (torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
        report = tuple(round(nbytes / bytes_per_gib, 2) for nbytes in raw)
    torch.cuda.reset_peak_memory_stats()
    return report
def add_mem_time(info, mem_usage, tim_usage):
    """Record a memory and wall-clock snapshot under the label ``info``.

    ``torch.cuda.synchronize()`` is called first so that queued GPU work is
    finished before measuring. ``mem_usage[info]`` receives ``see_memory()``
    and ``tim_usage[info]`` receives ``time.time()``. Both dicts are mutated
    in place and also returned.
    """
    torch.cuda.synchronize()
    mem_usage.update({info: see_memory()})
    tim_usage.update({info: time.time()})
    return mem_usage, tim_usage
class LossSpikeDetector:
    """Append a report to a log file whenever a task's loss jumps sharply.

    A spike is recorded for a task when its new loss is more than three times
    the previous loss recorded for that same task. Each report contains the
    iteration, the offending tasks, and the most recently registered batch
    data (from ``update_data``).
    """

    def __init__(self, log_path: str) -> None:
        # Last loss seen for each task name.
        self._last_loss: Dict[str, float] = {}
        # Window of the two most recent batches; [None] until update_data runs.
        self._last_data: List[Any] = [None]
        self._log_path = log_path

    def update_data(self, data: Any):
        """Register ``data`` as the current batch, keeping the last two."""
        self._last_data.append(data)
        if len(self._last_data) > 2:
            self._last_data = self._last_data[-2:]

    def update_loss(self, iteration: int, loss_map: Dict[str, float]):
        """Compare ``loss_map`` against the previous losses and log any spikes.

        Args:
            iteration: iteration number written into the report.
            loss_map: task name -> loss for this iteration.
        """
        loss_spike_result = []
        for task, loss in loss_map.items():
            if task in self._last_loss:
                # NOTE(review): a non-finite previous loss (NaN) makes this
                # comparison always False, so detection stalls for that task.
                if loss > self._last_loss[task] * 3:
                    # loss spike!
                    loss_spike_result.append(
                        {
                            "prev": self._last_loss[task],
                            "curr": loss,
                            "task": task,
                        }
                    )
            self._last_loss[task] = float(loss)
        if len(loss_spike_result) > 0:
            self._write_log(iteration, self._last_data[-1], loss_spike_result)

    def _write_log(self, iteration: int, data: Any, result: List[Dict[str, Any]]):
        """Append one spike report; ``data`` is an iterable of JSON-serializable items or None."""
        # The caller may pass a path inside a not-yet-existing directory
        # (e.g. "debug/..."); create it rather than crash mid-training.
        log_dir = os.path.dirname(self._log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(self._log_path, "a", encoding="utf-8") as fp:
            fp.write("=" * 20)
            fp.write("\nloss spike at {}\n".format(iteration))
            fp.write("{}\n".format(json.dumps(result, indent=4, ensure_ascii=False)))
            fp.write("data: \n")
            # data is None when a spike is detected before any update_data call.
            if data is not None:
                for d in data:
                    fp.write("{}\n".format(json.dumps(d, indent=4, ensure_ascii=False)))
            fp.write("\n\n")
def _gather_task_losses(loss_func, logits, targets, task_ids, task_names):
    """Return this step's per-task losses averaged over ranks, ordered by task name.

    For each local task index ``i``, the loss is recomputed over the full
    ``logits`` after replacing every target whose ``task_ids`` entry differs
    from ``i`` with ``-100``. That value is the ``ignore_index`` ``pretrain``
    configures on ``loss_func``. The local ``{task_name: loss}`` maps are
    exchanged with ``allgather_objects``. Each task's loss is then the mean
    over the ranks that reported it.
    """
    with torch.no_grad():
        task_num = len(task_names)
        # expand(task_num, -1, -1) prepends a task axis to the 2-D targets.
        targets_tmp = targets.expand(task_num, -1, -1)
        task = torch.arange(task_num, dtype=torch.int32, device="cuda")[:, None, None]
        targets_tmp = torch.where(
            task_ids == task,
            targets_tmp,
            torch.scalar_tensor(-100, dtype=torch.int32, device="cuda"),
        )
        flat_logits = logits.view(-1, logits.size(-1))
        local_task_loss: Dict[str, float] = {}
        for i in range(task_num):
            task_loss = loss_func(flat_logits, targets_tmp[i, :].view(-1))
            local_task_loss[task_names[i]] = task_loss.item()

    gathered: List[Dict[str, float]] = allgather_objects(local_task_loss)
    per_task: Dict[str, List[float]] = {}
    for rank_map in gathered:
        for task_name, task_loss in rank_map.items():
            per_task.setdefault(task_name, []).append(task_loss)
    return {
        task_name: sum(per_task[task_name]) / len(per_task[task_name])
        for task_name in sorted(per_task)
    }


def _save_checkpoint(args, model, optimizer, dataloader, iteration):
    """Write the checkpoint for ``iteration`` into ``args.save``.

    Files written:
      * ``<save_name>-<iteration>.pt``: model weights (``bmt.save``).
      * ``<save_name>.rank-<rank>.opt``: this rank's optimizer state. The name
        carries no iteration number, so each save overwrites the previous one.
      * ``<save_name>-<iteration>.data.pt``: dataloader state, rank 0 only.
    """
    bmt.save(model, os.path.join(args.save, args.save_name + ("-%d.pt" % iteration)))
    torch.save(
        optimizer.state_dict(),
        os.path.join(args.save, args.save_name + (".rank-%d.opt" % bmt.rank())),
    )
    # Every rank calls state_dict(); only rank 0 persists the result.
    all_states = dataloader.state_dict()
    if bmt.rank() == 0:
        # rank 0 writes the dataloader state
        torch.save(
            all_states,
            os.path.join(args.save, args.save_name + ("-%d.data.pt" % iteration)),
        )
    del all_states


def pretrain(
    args,
    tokenizer: CPMBeeTokenizer,
    model: CPMBee,
    optimizer: bmt.optim.AdamOffloadOptimizer,
    lr_scheduler: bmt.lr_scheduler.WarmupLRScheduler,
    optim_manager: bmt.optim.OptimManager,
):
    """Run the CPM-Bee pretraining loop over a ``MixedDataset``.

    Each iteration does the following:
      1. Moves one batch to the GPU (int32, ``context`` as bool).
      2. Runs forward/backward/clip/step through ``optim_manager``.
      3. Computes per-task losses averaged across ranks.
      4. Logs to stdout, to an optional ``LogManager`` (rank 0) and to
         optional TensorBoard (rank 0).

    Iteration numbers continue from ``args.start_step``, so the first batch is
    ``start_step + 1``. When ``args.save`` is set:
      * ``<save>/<save_name>-<start_step>.data.pt`` (if present) restores the
        dataloader state before training starts.
      * A checkpoint is written every ``args.save_iters`` iterations.
      * The final weights are saved to ``<save>/<save_name>.pt`` after the
        loop ends.

    Loss spikes are appended to ``debug/spile.<rank>.log``.
    """
    average_time = bmt.utils.AverageRecorder()
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
    start_step = args.start_step

    # The spike log lives under a relative "debug/" directory; create it up
    # front so the first spike does not crash on a missing folder.
    os.makedirs("debug", exist_ok=True)
    lsd = LossSpikeDetector("debug/spile.%d.log" % bmt.rank())

    writer = None
    if args.tensorboard is not None and bmt.rank() == 0:
        from torch.utils.tensorboard import SummaryWriter

        try:
            import distutils.version  # noqa: F401
        except ImportError:
            # Best-effort compatibility import: distutils was removed from
            # the standard library in Python 3.12.
            pass
        os.makedirs(args.tensorboard, exist_ok=True)
        writer = SummaryWriter(log_dir=args.tensorboard)

    log_mgr = None
    if args.log_dir is not None and bmt.rank() == 0:
        log_mgr = LogManager(args.log_dir)

    global_token_pass = 0.0
    global_world_size = bmt.world_size()

    dataloader = MixedDataset(
        args.dataset, args.batch_size, args.max_length, tokenizer, max_depth=8
    )
    # Without a save directory there is no dataset state to resume from
    # (and os.path.join(None, ...) would raise TypeError).
    if args.save is not None:
        data_state_path = os.path.join(
            args.save, args.save_name + ("-%d.data.pt" % start_step)
        )
        if os.path.exists(data_state_path):
            # load dataset states if exists
            dataset_states = torch.load(data_state_path)
            missing = dataloader.load_state_dict(dataset_states)
            if len(missing) > 0:
                bmt.print_rank("Missing keys when loading dataset states: ", missing)

    dataloader.start()
    try:
        for iteration, data in enumerate(dataloader, start=start_step + 1):
            assert data["inputs"].shape[0] == args.batch_size
            input_ids = torch.from_numpy(data["inputs"]).cuda().to(torch.int32)
            input_ids_sub = torch.from_numpy(data["inputs_sub"]).cuda().to(torch.int32)
            input_length = torch.from_numpy(data["length"]).cuda().to(torch.int32)
            input_context = torch.from_numpy(data["context"]).cuda().bool()
            input_sample_ids = torch.from_numpy(data["sample_ids"]).cuda().to(torch.int32)
            input_num_segments = torch.from_numpy(data["num_segments"]).cuda().to(torch.int32)
            input_segment_ids = torch.from_numpy(data["segment_ids"]).cuda().to(torch.int32)
            input_segment_rel_offset = (
                torch.from_numpy(data["segment_rel_offset"]).cuda().to(torch.int32)
            )
            input_segment_rel = torch.from_numpy(data["segment_rel"]).cuda().to(torch.int32)
            input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
            targets = torch.from_numpy(data["target"]).cuda().to(torch.int32)
            ext_table_ids = torch.from_numpy(data["ext_ids"]).cuda().to(torch.int32)
            ext_table_sub = torch.from_numpy(data["ext_sub"]).cuda().to(torch.int32)
            task_ids = torch.from_numpy(data["task_ids"]).cuda().to(torch.int32)
            task_names = data["task_names"]
            lsd.update_data(data["raw_data"])

            # =========== init
            optim_manager.zero_grad()
            mem_usage = {}
            tim_usage = {}
            mem_usage, tim_usage = add_mem_time("init", mem_usage, tim_usage)

            # =========== forward
            logits, _ = model(
                input_ids,
                input_ids_sub,
                input_length,
                input_context,
                input_sample_ids,
                input_num_segments,
                input_segment_ids,
                input_segment_rel_offset,
                input_segment_rel,
                input_span,
                ext_table_ids,
                ext_table_sub,
            )
            loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
            global_loss = bmt.sum_loss(loss).item()
            mem_usage, tim_usage = add_mem_time("forward", mem_usage, tim_usage)

            # =========== backward
            optim_manager.backward(loss)
            mem_usage, tim_usage = add_mem_time("backward", mem_usage, tim_usage)

            # =========== optimizer step
            current_stream = torch.cuda.current_stream()
            # some reduce ops of distributed parameter were launched on load stream
            current_stream.wait_stream(bmt.config["load_stream"])
            grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)
            optim_manager.step()
            mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)

            # =========== statistics
            iteration_time = tim_usage["optim"] - tim_usage["init"]
            average_time.record(iteration_time)

            task_loss_map = _gather_task_losses(
                loss_func, logits, targets, task_ids, task_names
            )

            local_total_rate = torch.Tensor([input_length.float().mean() / args.max_length]).cuda()
            local_total_rate = bmt.sum_loss(local_total_rate).item()
            global_token_pass += (
                global_world_size * local_total_rate * args.max_length * args.batch_size
            )
            avg_time = average_time.value
            lsd.update_loss(iteration, task_loss_map)

            train_info = {
                "time": tim_usage["init"],
                "iteration": iteration,
                "loss": global_loss,
                "lr": lr_scheduler.current_lr,
                "lr_scale": int(optim_manager.loss_scale),
                "time_usage": tim_usage,
                "mem_usage": mem_usage,
                "avg_time": avg_time,
                "token_max": local_total_rate,
                "token_pass": global_token_pass,
                "throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
                "grad_norm": grad_norm.item(),
                "mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
                "num_gpus": global_world_size,
                "task_loss": task_loss_map,
            }

            bmt.print_rank(
                (
                    "| Iter: {:6d} | loss: {:.4f} | lr: {:.4e}, scale: {:10.4f} | time: {:.4f} |"
                    + " token/max: {:.4f} | mask/max: {:.4f} | grad_norm: {:.4f}"
                ).format(
                    iteration,
                    global_loss,
                    lr_scheduler.current_lr,
                    int(optim_manager.loss_scale),
                    avg_time,
                    input_length.float().mean() / args.max_length,
                    (targets >= 0).sum(-1).float().mean() / args.max_length,
                    grad_norm,
                )
            )
            bmt.print_rank(
                "| "
                + " | ".join(
                    [
                        "{} loss: {:.4f}".format(task_name, loss)
                        for task_name, loss in task_loss_map.items()
                    ]
                )
            )

            if iteration % args.inspect_iters == 0:
                model_inspect = bmt.inspect.inspect_model(model, "*")
                bmt.print_rank(bmt.inspect.format_summary(model_inspect))
                train_info["model_inspect"] = model_inspect

            # write log here
            if log_mgr is not None:
                log_mgr.write(**train_info)
            if writer is not None:
                writer.add_scalar("Loss/train", global_loss, iteration)
                writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
                writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
                writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
                for task_name, loss in task_loss_map.items():
                    writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)

            if args.save is not None and iteration % args.save_iters == 0:
                _save_checkpoint(args, model, optimizer, dataloader, iteration)
    finally:
        dataloader.close()
        if writer is not None:
            # Flush pending TensorBoard events and release the file handle.
            writer.close()

    if args.save is not None:
        bmt.save(model, os.path.join(args.save, args.save_name + ".pt"))
def main():
    """Entry point: initialize, build components, then run pretraining."""
    args = initialize()
    # (tokenizer, model, optimizer, lr_scheduler, optim_manager) in the
    # positional order pretrain() expects after args.
    components = setup_model_and_optimizer(args)
    pretrain(args, *components)


if __name__ == "__main__":
    main()