-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
stage3.py
2817 lines (2245 loc) · 129 KB
/
stage3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import sys
import gc
import collections
from typing import Deque, Dict, Tuple
from deepspeed import comm as dist
from deepspeed.utils import groups
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.utils import logger
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.utils import apply_to_tensors_only
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper
from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE, LOSS_SCALER
from deepspeed.accelerator import get_accelerator
from deepspeed.utils import z3_leaf_parameter
# Toggle this to true to enable correctness test
# with gradient partitioning and without.
# NOTE(review): not referenced in the visible part of this file; presumably read
# by debugging code elsewhere in the module — confirm before relying on it.
pg_correctness_test = False
# Timer names used by this optimizer (presumably keys for the `timers` object
# passed to the optimizer constructor — TODO confirm against their users).
OPTIMIZER_SWAP_IN_STATE_TIMER = 'optimizer_swap_in_state'
INIT_OPTIMIZER_TIMER = 'init_optimizer_state'
OPTIMIZER_SWAP_OUT_STATE_TIMER = 'optimizer_swap_out_state'
OPTIMIZER_STEP_TIMER = 'optimizer_step'
def print_rank_0(message, debug=False, force=False):
    """Log *message* at INFO level on global rank 0, if *debug* or *force* is set.

    With both flags left at their default (False) nothing is logged. However,
    ``dist.get_rank()`` is still queried on every call.
    """
    rank = dist.get_rank()
    enabled = debug or force
    if enabled and rank == 0:
        logger.info(message)
def input(msg):
    """No-op stand-in for the builtin ``input``: ignore *msg* and return None.

    It shadows the builtin for the whole module. Presumably this keeps any
    leftover interactive debugging prompts from blocking a training run
    (TODO confirm intent).
    """
    return None
def isclose(a, b, rtol=1e-09, atol=0.0):
    """Return True if *a* and *b* are approximately equal.

    Follows the semantics of :func:`math.isclose`:

    - the values are close when
      ``abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)``;
    - exactly equal values, including equal infinities, are always close;
    - an infinity is never close to a different value;
    - NaN is never close to anything.

    Args:
        a, b: values to compare.
        rtol: relative tolerance, scaled by the larger magnitude of *a* and *b*.
        atol: absolute tolerance floor.
    """
    if a == b:
        # Without this short-circuit, inf - inf is NaN and two identical
        # infinities would compare as "not close".
        return True
    diff = abs(a - b)
    if diff == float("inf"):
        # Exactly one operand is infinite (or the difference overflowed).
        # Otherwise rtol * inf == inf would make inf "close" to any finite value.
        return False
    return diff <= max(rtol * max(abs(a), abs(b)), atol)
def lcm(x, y):
    """Return the least common multiple of the integers *x* and *y*.

    Uses :func:`math.gcd`. The previous ``fractions.gcd`` import no longer
    exists in Python 3.9+, so calling this function raised ImportError.
    Returns 0 when either argument is 0; the plain formula would otherwise
    divide by ``gcd(0, 0) == 0``.
    """
    from math import gcd
    if x == 0 or y == 0:
        return 0
    return x * y // gcd(x, y)
def move_to_cpu(tensor_list):
    """Move every tensor in *tensor_list* to host memory in place.

    Only each tensor's ``.data`` is rebound to its ``.cpu()`` copy. The tensor
    objects themselves are kept, so existing references observe the move.
    Returns None.
    """
    for t in tensor_list:
        host_copy = t.data.cpu()
        t.data = host_copy
# Sentinel micro-step id. Presumably it marks "no micro step taken yet" — confirm
# against its users: it is not referenced in the visible part of this file, and
# __init__ starts `micro_step_id` at 0, not at this value.
INITIAL_MICRO_STEP_ID = -1
class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
    """
    DeepSpeedZeroOptimizer designed to reduce the memory footprint
    required for training large deep learning models.

    This is the ZeRO stage 3 wrapper around a user optimizer (``init_optimizer``).
    Each data-parallel rank keeps only its partition (``ds_tensor``) of the fp16
    model parameters, grouped into sub-groups, together with the matching fp32
    partitions that the wrapped optimizer updates. Parameter partitioning is
    delegated to a ``DeepSpeedZeRoOffload`` helper. Parameter and optimizer-state
    offload to CPU or NVMe is controlled by ``offload_param_config`` and
    ``offload_optimizer_config``.

    For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
    https://arxiv.org/abs/1910.02054

    For usage examples, refer to TODO: DeepSpeed Tutorial
    """
def __init__(
    self,
    module,
    init_optimizer,
    timers,
    ds_config,
    static_loss_scale=1.0,
    dynamic_loss_scale=False,
    dynamic_loss_args=None,
    verbose=True,
    contiguous_gradients=True,
    reduce_bucket_size=500000000,
    prefetch_bucket_size=50000000,
    max_reuse_distance=1000000000,
    max_live_parameters=1000000000,
    param_persistence_threshold=100000,
    model_persistence_threshold=sys.maxsize,
    dp_process_group=None,
    reduce_scatter=True,
    overlap_comm=False,
    offload_optimizer_config=None,
    offload_param_config=None,
    sub_group_size=1000000000000,
    offload_ratio=0.0,
    mpu=None,
    clip_grad=0.0,
    gradient_accumulation_dtype=torch.float32,
    communication_data_type=torch.float16,
    postscale_gradients=True,
    gradient_predivide_factor=1.0,
    gradient_accumulation_steps=1,
    elastic_checkpoint=False,
    aio_config=None,
    all2all_process_group=None,
    zero_hpz_partition_size=1,
    zero_quantized_weights=False,
    zero_quantized_nontrainable_weights=False,
):
    """Build the ZeRO-3 optimizer around ``init_optimizer``.

    Construction order:
      1. Create the ``DeepSpeedZeRoOffload`` parameter-partitioning helper.
      2. Apply the offload configs.
      3. Create the fp16 partitions / flat buffers.
      4. Create the fp32 partitions and optimizer state (``_setup_for_real_optimizer``).
      5. Register the gradient hooks.
      6. Create the loss scaler.

    Args:
        module: the model. Every ``module.parameters()`` entry gets
            ``_z3_optimizer = self`` at the end of construction.
        init_optimizer: the wrapped optimizer. Its ``param_groups`` define the
            trainable parameters, and its first parameter's dtype becomes
            ``self.dtype``.
        timers, ds_config, prefetch_bucket_size, max_reuse_distance,
        max_live_parameters, param_persistence_threshold,
        model_persistence_threshold, dp_process_group, offload_param_config,
        mpu, zero_quantized_weights, zero_quantized_nontrainable_weights:
            forwarded to ``DeepSpeedZeRoOffload`` (some are also stored here).
        static_loss_scale, dynamic_loss_scale, dynamic_loss_args: passed to
            ``CreateLossScaler``.
        contiguous_gradients: when true, a contiguous IPG bucket buffer of
            ``reduce_bucket_size`` elements is allocated.
        offload_ratio: stored as ``partial_offload``. When optimizer offload is
            on and this is not 1.0, a GPU ``torch.optim.AdamW`` backup optimizer
            is also built (``init_optimizer`` must then be ``DeepSpeedCPUAdam``).
        reduce_scatter: when true, ``communication_data_type`` must be
            fp16/bf16/fp32, ``gradient_predivide_factor`` must be 1.0 and
            ``postscale_gradients`` must be true.
        all2all_process_group: when given, ``reduce_scatter`` must be true.
        verbose: NOTE(review): not used by this body, apart from appearing in
            the ``locals()`` debug print.

    Raises:
        SystemError: if no accelerator is available.
        AssertionError: on the unsupported option combinations listed above.
    """
    see_memory_usage("Stage 3 initialize beginning", force=True)
    print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)
    if dist.get_rank() == 0:
        logger.info(f"Reduce bucket size {reduce_bucket_size}")
        logger.info(f"Prefetch bucket size {prefetch_bucket_size}")
    # The fused optimizer does all the work. We need this layer for two reasons:
    # 1. maintain same user API from apex.fp16_utils
    # 2. keep common stuff here in case we need to add a new fused optimizer later
    # differences from apex.fp16_utils:
    # - assume all model params in fp16
    # - assume all params requires grad
    # - flat by groups, not keeping state. TODO: remove state explicitly?
    # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
    if not get_accelerator().is_available():
        raise SystemError("Cannot use fp16 without accelerator.")
    self.optimizer = init_optimizer
    # Use torch (un)flatten ops
    self.flatten = _flatten_dense_tensors
    self.unflatten = _unflatten_dense_tensors
    # Working dtype is taken from the first parameter of the first param group.
    self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
    self.gradient_accumulation_dtype = gradient_accumulation_dtype
    self._global_grad_norm = 0.
    self.custom_loss_scaler = False
    self.external_loss_scale = None
    # Offload defaults; overridden by _configure_offloading() below.
    self.optimizer_swapper = None
    self.swap_optimizer = False
    self.offload_optimizer = False
    self.offload_optimizer_pin_memory = False
    self.offload_optimizer_fast_init = False
    self.offload_param = False
    self.offload_param_pin_memory = False
    self.params_in_nvme_and_cpu = False
    self.max_params_in_cpu = 0
    self.partial_offload = offload_ratio
    #num of ranks in a ZeRO param partitioning group
    self.zero_hpz_partition_size = zero_hpz_partition_size
    zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
    print_rank_0(
        f"ZeRO Stage 3 param partitioning group {self.zero_hpz_partition_size} {zero_param_parallel_group}",
        force=False)
    # Lazily create the param-partitioning group when hpZ is requested but no group exists yet.
    if self.zero_hpz_partition_size > 1 and zero_param_parallel_group is None:
        self._set_zero_group_parallelism()
        zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
    self.parameter_offload = self.initialize_ds_offload(
        module=module,
        timers=timers,
        ds_config=ds_config,
        overlap_comm=overlap_comm,
        prefetch_bucket_size=prefetch_bucket_size,
        max_reuse_distance=max_reuse_distance,
        max_live_parameters=max_live_parameters,
        param_persistence_threshold=param_persistence_threshold,
        model_persistence_threshold=model_persistence_threshold,
        dp_process_group=dp_process_group,
        offload_param_config=offload_param_config,
        mpu=mpu,
        zero_param_parallel_group=zero_param_parallel_group,
        zero_quantized_weights=zero_quantized_weights,
        zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)
    self.persistent_parameters = self.parameter_offload.persistent_parameters
    self._configure_offloading(offload_optimizer_config, offload_param_config)
    # backup fused_adam optimizer init
    # With partial optimizer offload, a GPU AdamW mirroring the first param
    # group's hyper-parameters is created next to the CPU optimizer.
    if self.offload_optimizer and self.partial_offload != 1.0:
        backup_gpu_tensor = torch.randn(1, device=get_accelerator().device_name()).to(self.dtype)
        backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor)
        assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam'
        self.backup_optimizer = torch.optim.AdamW([backup_gpu_param],
                                                  lr=self.optimizer.param_groups[0]["lr"],
                                                  betas=self.optimizer.param_groups[0]["betas"],
                                                  eps=self.optimizer.param_groups[0]["eps"],
                                                  weight_decay=self.optimizer.param_groups[0]["weight_decay"],
                                                  amsgrad=self.optimizer.param_groups[0]["amsgrad"])
        # Multiple param_groups configs for back-up optimizer
        if len(self.optimizer.param_groups) > 1:
            for i in range(1, len(self.optimizer.param_groups)):
                self.backup_optimizer.add_param_group(self.optimizer.param_groups[i])
    self.module = module
    self.elastic_checkpoint = elastic_checkpoint
    self.inf_or_nan_tracker: Tensor = torch.zeros(1,
                                                  dtype=torch.bool,
                                                  device=get_accelerator().current_device_name(),
                                                  requires_grad=False)
    self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam)
    self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu
    ### streams used for overlapping computation with communication
    # None on synchronized devices; otherwise a dedicated stream when
    # overlap_comm is set, else the accelerator's default stream.
    self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator(
    ).Stream() if overlap_comm else get_accelerator().default_stream()
    ############################################################################
    self.n_caching_allocator_flushes = 0
    #-------------Stage 3 Setup-------------------#
    self.timers = timers
    self.all2all_process_group = all2all_process_group
    self.reduce_scatter = reduce_scatter
    self.dp_process_group = self.parameter_offload.dp_process_group
    self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
    # NOTE(review): duplicate of the all2all_process_group assignment a few lines above.
    self.all2all_process_group = all2all_process_group
    self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
    self.partition_count = dist.get_world_size(group=self.dp_process_group)
    if mpu is None:
        self.model_parallel_group = None
        self.model_parallel_rank = 0
    else:
        self.model_parallel_group = mpu.get_model_parallel_group()
        self.model_parallel_rank = mpu.get_model_parallel_rank()
    self.overflow = False
    self.clip_grad = clip_grad
    self.communication_data_type = communication_data_type
    self.gradient_predivide_factor = gradient_predivide_factor
    self.postscale_gradients = postscale_gradients
    self.gradient_accumulation_steps = gradient_accumulation_steps
    self.micro_step_id = 0
    self.reduce_bucket_size = int(reduce_bucket_size)
    # NOTE(review): the `is not None` half of this assert is already guaranteed by the enclosing `if`.
    if self.all2all_process_group is not None:
        assert self.all2all_process_group is not None and self.reduce_scatter == True, "when enable all_to_all_reduce, reduce_scatter should also be enabled for data type checks."
    if self.reduce_scatter:
        valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
        assert self.communication_data_type in valid_reduce_scatter_dtypes, f"ZeRO-3 supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
        assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-3 with reduce scatter enabled"
        assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-3 with reduce scatter enabled"
    # Holds the model parameters
    # The param.data may not hold any meaningful data
    # when param's status is NOT_AVAILABLE or IN_FLIGHT
    self.fp16_groups = []
    # Hold partitioned parameters
    self.fp16_partitioned_groups = []
    # Holds a fused and flattened copy of the parameters
    self.fp16_partitioned_groups_flat = []
    self.fp16_partitioned_groups_flat_numel = []
    self.fp16_partitioned_groups_flat_id = []
    #defragmented pinned memory
    self.param_groups_fp16_flat_cpu_memory = []
    #a single 32-bit partition of the parallel partitioned parameters
    #that this process will update
    self.fp32_partitioned_groups_flat = []
    self.next_swappable_fp32_partitioned_groups = []
    # number of elements per partition in each group
    self.partition_size = []
    self.all_reduce_print = False
    self.prefetch_elements = int(prefetch_bucket_size)
    self.contiguous_gradients = contiguous_gradients
    # padding on each partition for alignment purposes
    self.groups_padding = []
    self.sub_group_size = sub_group_size
    self.sub_group_to_group_id = {}
    # Trainable parameters
    self.trainable_param_groups = self._get_trainable_parameter_groups()
    see_memory_usage("Before creating fp16 partitions", force=True)
    self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups)
    num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
    see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", force=True)
    # Optimizer tensor swapping
    if self.swap_optimizer:
        self._configure_tensor_swapping(offload_optimizer_config, aio_config)
    self.is_gradient_accumulation_boundary: bool = True
    self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque()
    # TODO. make this configurable via JSON
    self.max_param_reduce_events: int = 2
    self.param_dict = {}
    # map between param_id and bool to specify if a param is in this partition
    self.is_param_in_current_partition = {}
    self.extra_large_param_to_reduce = None
    self.grads_in_ipg_bucket = []
    self.params_in_ipg_bucket = []
    self.params_already_reduced = {}
    # NOTE(review): redundant — already set (with annotation) a few lines above.
    self.is_gradient_accumulation_boundary = True
    self._release_ipg_buffers()
    self.previous_reduced_grads = None
    # model parameter traversal-based param id that's stable across runs
    for params_group in self.fp16_groups:
        for param in params_group:
            param_id = self.get_param_id(param)
            self.param_dict[param_id] = param
            self.params_already_reduced[param_id] = False
    #Largest partitioned param
    largest_partitioned_param_numel = 0
    for fp16_partitioned_group in self.fp16_partitioned_groups:
        if len(fp16_partitioned_group) > 0:
            largest_partitioned_param_numel = max(
                largest_partitioned_param_numel,
                max([max(tensor.numel(), tensor.ds_numel) for tensor in fp16_partitioned_group]))
    print_rank_0(f'Largest partitioned param numel = {largest_partitioned_param_numel}', force=False)
    self._setup_for_real_optimizer()
    self.grad_position = {}
    self.set_grad_positions()
    if self.offload_optimizer:
        self.norm_for_param_grads = {}
    # stores if a partition has been reduced in this step
    self.is_partition_reduced = {}
    # stores if a grad in a partition has been computed or not
    self.is_grad_computed = {}
    # will store the averaged gradients required by this partition
    self.averaged_gradients = {}
    #creates backward hooks for gradient partitioning
    ###Calls all gather param
    self._grad_acc_hooks = []
    self._leaf_module_hooks = []
    self.create_reduce_and_remove_grad_hooks()
    #exit(0)
    # we may have a way of fusing dynamic scale. Do not support for now
    self.loss_scaler = CreateLossScaler(dtype=self.dtype,
                                        static_loss_scale=static_loss_scale,
                                        dynamic_scaling=dynamic_loss_scale,
                                        dynamic_loss_args=dynamic_loss_args)
    self.dynamic_loss_scale = self.loss_scaler.dynamic
    self.debug_fp16_grads = [{} for _ in self.fp16_groups]
    self._link_all_hp_params()
    if dist.get_rank(group=self.dp_process_group) == 0:
        see_memory_usage(f"After initializing ZeRO optimizer", force=True)
def destroy(self):
    """Tear down the hooks and buffers owned by this optimizer.

    Steps:
      - calls ``destroy()`` on the parameter-offload helper;
      - removes every gradient-accumulation and leaf-module hook this
        optimizer registered;
      - releases the contiguous IPG bucket buffer, if one exists.

    Safe when ``contiguous_gradients`` was disabled (no IPG buffer was ever
    allocated) and when the buffer has already been released.
    """
    self.parameter_offload.destroy()
    for hook in self._grad_acc_hooks:
        hook.remove()
    for hook in self._leaf_module_hooks:
        hook.remove()
    # The handles are dead once removed; drop them so a repeated destroy()
    # does not touch them again.
    self._grad_acc_hooks.clear()
    self._leaf_module_hooks.clear()
    print_rank_0("Removed grad acc hooks", force=False)
    # _setup_for_real_optimizer only allocates this buffer when
    # contiguous_gradients is enabled, so it may legitimately be missing.
    try:
        del self.__ipg_bucket_flat_buffer
    except AttributeError:
        pass
def initialize_ds_offload(
    self,
    module,
    timers,
    ds_config,
    overlap_comm,
    prefetch_bucket_size,
    max_reuse_distance,
    max_live_parameters,
    param_persistence_threshold,
    model_persistence_threshold,
    dp_process_group,
    offload_param_config,
    mpu,
    zero_param_parallel_group,
    zero_quantized_weights,
    zero_quantized_nontrainable_weights,
):
    """Construct the ``DeepSpeedZeRoOffload`` helper that partitions this model's parameters.

    This is a thin factory: every argument is forwarded unchanged, by keyword.
    It is kept as an overridable method, presumably so subclasses can supply a
    different offload implementation (TODO confirm).
    """
    offload_kwargs = dict(
        module=module,
        timers=timers,
        ds_config=ds_config,
        overlap_comm=overlap_comm,
        prefetch_bucket_size=prefetch_bucket_size,
        max_reuse_distance=max_reuse_distance,
        max_live_parameters=max_live_parameters,
        param_persistence_threshold=param_persistence_threshold,
        model_persistence_threshold=model_persistence_threshold,
        dp_process_group=dp_process_group,
        offload_param_config=offload_param_config,
        mpu=mpu,
        zero_param_parallel_group=zero_param_parallel_group,
        zero_quantized_weights=zero_quantized_weights,
        zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
    )
    return DeepSpeedZeRoOffload(**offload_kwargs)
def _get_trainable_parameter_groups(self):
param_groups = []
PARAMS_KEY = "params"
for param_group in self.optimizer.param_groups:
trainable_params = [p for p in param_group[PARAMS_KEY] if p.requires_grad]
if len(trainable_params) == 0:
continue
trainable_param_group = {}
for key in param_group.keys():
if key == PARAMS_KEY:
trainable_param_group[PARAMS_KEY] = trainable_params
else:
trainable_param_group[key] = param_group[key]
param_groups.append(trainable_param_group)
return param_groups
def _set_zero_group_parallelism(self):
    """Create the ZeRO parameter-partitioning process group of size ``zero_hpz_partition_size``."""
    partition_size = self.zero_hpz_partition_size
    groups._create_zero_param_parallel_group(partition_size)
def invalidate_secondary_tensor(self):
    """Clear ``ds_secondary_tensor`` on every parameter of every fp16 sub-group.

    Parameters whose secondary tensor is already None are left untouched.
    """
    for sub_group in self.fp16_groups:
        for p in sub_group:
            if p.ds_secondary_tensor is None:
                continue
            p.ds_secondary_tensor = None
def _setup_for_real_optimizer(self):
    """Create the fp32 partitions, optimizer state and gradient buffers for the wrapped optimizer.

    Steps, in order. The ``dist.barrier()`` calls keep all ranks in step
    between phases.
      1. Build the fp32 partitions and the "next swappable group" table used
         for pipelined optimizer swapping.
      2. Initialize the wrapped optimizer's state.
      3. Only when ``contiguous_gradients`` is set: allocate the contiguous IPG
         bucket buffer of ``reduce_bucket_size`` elements of ``self.dtype`` on
         the current accelerator device.
      4. Allocate one zero-filled flat gradient buffer covering every trainable
         parameter's partition. It uses ``gradient_accumulation_dtype`` on
         ``self.device`` and is pinned when ``offload_optimizer_pin_memory`` is
         set. Each parameter's ``ds_id`` is mapped to its ``narrow()`` slice of
         this buffer.
    """
    see_memory_usage("Before creating fp32 partitions", force=True)
    self._create_fp32_partitions()
    see_memory_usage("After creating fp32 partitions", force=True)
    dist.barrier()
    # To support pipelined optimizer swapping
    self._create_next_swappable_fp32_groups()
    see_memory_usage("Before initializing optimizer states", force=True)
    self.initialize_optimizer_states()
    see_memory_usage("After initializing optimizer states", force=True)
    dist.barrier()
    if dist.get_rank() == 0:
        logger.info(f"optimizer state initialized")
    # IPG
    if self.contiguous_gradients:
        self.__ipg_bucket_flat_buffer: Tensor = torch.empty(self.reduce_bucket_size,
                                                            dtype=self.dtype,
                                                            device=get_accelerator().current_device_name())
    # NOTE(review): this None is overwritten a few lines below; it looks redundant.
    self.grad_partitions_flat_buffer = None
    self.__param_id_to_grad_partition: Dict[int, Tensor] = {}
    all_params = list(itertools.chain.from_iterable(self.fp16_groups))
    self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
                                                           dtype=self.gradient_accumulation_dtype,
                                                           device=self.device)
    if self.offload_optimizer_pin_memory:
        self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)
    # Carve the flat buffer into per-parameter views, in fp16_groups order.
    offset = 0
    for param in all_params:
        self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
            0, offset, param.partition_numel())
        offset += param.partition_numel()
def _link_all_hp_params(self):
for p in self.module.parameters():
p._z3_optimizer = self
def set_lr(self, lr):
    """Set the learning rate: write *lr* into every param group of the wrapped optimizer."""
    param_groups = self.optimizer.param_groups
    for group in param_groups:
        group["lr"] = lr
def get_lr(self):
    """Return the current learning rate, taken from the wrapped optimizer's first param group."""
    first_group = self.optimizer.param_groups[0]
    return first_group["lr"]
# TODO. factor out to a utility outside of stage3
@staticmethod
def defragment(tensors: List[Tensor]) -> Tensor:
    """move provided tensors into a contiguous flat buffer, with some additional
    measures taken to reduce memory fragmentation

    All *tensors* must share one dtype and one device (asserted). The procedure:
      1. Copy each tensor's values into one flat CPU buffer.
      2. Shrink each tensor's ``.data`` to an empty tensor, releasing its
         device storage.
      3. Run ``gc.collect()`` and the accelerator's ``empty_cache()``.
      4. Copy the flat buffer back to the original device as a single allocation.
      5. Re-point every tensor's ``.data`` at its ``narrow()`` slice of that
         allocation.

    The ordering matters: the device copies must be freed before the single
    large re-allocation. Returns the flat device buffer. The tensor objects
    themselves are preserved; only their ``.data`` changes.
    """
    assert len(set(t.dtype for t in tensors)) == 1
    assert len(set(t.device for t in tensors)) == 1
    cpu_buffer = torch.empty(sum(p.numel() for p in tensors),
                             dtype=get_only_unique_item(t.dtype for t in tensors),
                             device="cpu")
    tensor_infos: List[Tuple[Tensor, int, int]] = []
    orig_device = get_only_unique_item(t.device for t in tensors)
    offset = 0
    for tensor in tensors:
        tensor_numel = tensor.numel()
        # move the tensor from device memory to host memory
        cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
        # drop the device storage; only the (now empty) tensor object is kept
        tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)
        # record some data so we can restore the device tensor later
        tensor_infos.append((tensor, offset, tensor_numel))
        offset += tensor_numel
    # release the freed device blocks before the single large allocation below
    gc.collect()
    get_accelerator().empty_cache()
    # copy tensors (now flattened and contiguous) back to GPU
    device_buffer = cpu_buffer.to(orig_device)
    # restore device tensors
    for tensor, offset, tensor_numel in tensor_infos:
        tensor.data = device_buffer.narrow(0, offset, tensor_numel)
    return device_buffer
def _get_param_coordinator(self, training):
return self.parameter_offload.get_param_coordinator(training)
def _configure_offloading(self, offload_optimizer_config, offload_param_config):
    """Derive the optimizer and parameter offload flags from the two offload configs.

    A config enables offloading only when it is not None and its ``device``
    differs from ``OffloadDeviceEnum.none``. Selecting ``OffloadDeviceEnum.nvme``
    additionally turns on NVMe swapping (``swap_optimizer`` for the optimizer,
    ``params_in_nvme_and_cpu`` for parameters). When a config does not enable
    offloading, the corresponding attributes keep their current values.
    """
    # --- optimizer-state offload ---
    opt_cfg = offload_optimizer_config
    if opt_cfg is not None and opt_cfg.device != OffloadDeviceEnum.none:
        self.offload_optimizer = True
        self.offload_optimizer_pin_memory = opt_cfg.pin_memory
        self.swap_optimizer = opt_cfg.device == OffloadDeviceEnum.nvme
        self.offload_optimizer_fast_init = opt_cfg.fast_init

    # --- parameter offload ---
    param_cfg = offload_param_config
    param_offload_enabled = param_cfg is not None and param_cfg.device != OffloadDeviceEnum.none
    if not param_offload_enabled:
        return
    self.offload_param = True
    self.offload_param_pin_memory = param_cfg.pin_memory
    self.params_in_nvme_and_cpu = param_cfg.device == OffloadDeviceEnum.nvme
    self.max_params_in_cpu = param_cfg.max_in_cpu
    print_rank_0(
        f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
        force=False)
def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
    """Create the NVMe optimizer-state swapper.

    The swap directory is ``<nvme_path>/zero_stage_3`` and is created if
    missing. When ``offload_optimizer_config.pipeline`` is set, a
    ``PipelinedOptimizerSwapper`` is used; otherwise a
    ``PartitionedOptimizerSwapper``. The swapper is sized for the largest fp16
    sub-group and swaps fp32 tensors.
    """
    swap_dir = os.path.join(offload_optimizer_config.nvme_path, 'zero_stage_3')
    os.makedirs(swap_dir, exist_ok=True)
    if dist.get_rank() == 0:
        logger.info(f'Tensor Swapping: Adding optimizer tensors')

    if offload_optimizer_config.pipeline:
        swapper_cls = PipelinedOptimizerSwapper
    else:
        swapper_cls = PartitionedOptimizerSwapper

    largest_sub_group_numel = max(self.fp16_partitioned_groups_flat_numel)
    self.optimizer_swapper = swapper_cls(swap_config=offload_optimizer_config,
                                         aio_config=aio_config,
                                         base_folder=swap_dir,
                                         optimizer=self.optimizer,
                                         largest_numel=largest_sub_group_numel,
                                         device=self.device,
                                         dtype=torch.float32,
                                         timers=self.timers)
@property
def elements_in_ipg_bucket(self):
    """Total ``ds_numel`` of the parameters currently queued in the IPG bucket."""
    total = 0
    for queued in self.params_in_ipg_bucket:
        total += queued.ds_numel
    return total
def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
    '''Place the fp16 partitions (``ds_tensor``) of *param_list* into consecutive slices of *flat_buffer*.

    If flat buffer is None then the parameters in the param_list are
    not copied to the flat buffer. This is because they exceed the number of max_params_in_cpu
    Some of these parameters may already be in CPU in unflattened buffers
    or they maybe in GPU, or they maybe in NVME. If they are in NVME, then
    they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are
    needed during training.

    Args:
        param_list: parameters whose ``ds_tensor`` partitions are laid out back to back.
        flat_buffer: destination buffer, or None to leave the partitions where they are.
        avoid_copy: when True, an AVAILABLE partition's ``.data`` is re-pointed at
            its slice without first copying the current values into it.
            NOTE(review): the visible caller only reaches this path with
            offload_param enabled, where it passes avoid_copy=False.
    '''
    if flat_buffer is None:
        # this dst buffer is on NVMe, so skip this
        return
    start = 0
    for param in param_list:
        src = param.ds_tensor
        dest = flat_buffer.narrow(0, start, src.ds_numel)
        start = start + src.ds_numel
        '''if the parameter was initialized in nvme then bring it to the destination buffer directly'''
        if src.status == PartitionedParamStatus.NOT_AVAILABLE:
            # Swap straight into the slice, then alias the partition to it.
            print_rank_0(
                f"Swapping in {param.ds_id} with partition size {param.partition_numel()} permanently to CPU")
            param.nvme_swapper.swap_into_buffer(param, dest)
            src.data = dest.data
            src.status = PartitionedParamStatus.AVAILABLE
        else:
            assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here"
            if not avoid_copy:
                dest.data.copy_(src.data)
            # From here on the partition's storage *is* its slice of flat_buffer.
            src.data = dest.data
        # Final location must be gpu/cpu in this case
        param.ds_tensor.final_location = 'not-nvme'
def _create_param_groups_fp16_flat_cpu_memory(self):
    """Allocate one flat host buffer per trainable param group for the offloaded fp16 partitions.

    Normally each group's buffer holds all of that group's partition elements.
    When ``params_in_nvme_and_cpu`` is set, the running total across groups is
    capped at ``max_params_in_cpu``: a group that would exceed the cap gets
    only the remaining budget (possibly zero). Groups with room receive a
    pinned buffer from the accelerator. Groups without room receive a
    1-element placeholder, so the list stays index-aligned with
    ``trainable_param_groups``.
    """
    cpu_elements_used = 0
    for group_idx, group in enumerate(self.trainable_param_groups):
        group_numel = sum([p.partition_numel() for p in group['params']])

        over_budget = self.params_in_nvme_and_cpu and cpu_elements_used + group_numel > self.max_params_in_cpu
        buffer_numel = max(0, self.max_params_in_cpu - cpu_elements_used) if over_budget else group_numel
        cpu_elements_used += group_numel

        if buffer_numel <= 0:
            print_rank_0(f"No flat buffer size. Param group size was {group_numel}", force=False)
            placeholder = torch.empty(1, dtype=self.dtype)
            self.param_groups_fp16_flat_cpu_memory.append(placeholder)
            continue

        print_rank_0(f"group {group_idx} flat buffer size {buffer_numel}", force=False)
        pinned_buffer = get_accelerator().pin_memory(torch.empty(int(buffer_numel), dtype=self.dtype))
        self.param_groups_fp16_flat_cpu_memory.append(pinned_buffer)
def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
    """Split the trainable param groups into fp16 sub-groups and place their partitions in flat buffers.

    Bookkeeping recorded for every sub-group:
      - the parameters, their ``ds_tensor`` partitions, the owning group index,
        the total partition elements and the ``ds_id`` list;
      - per-parameter padding, which is non-zero only on the last rank of the
        data-parallel group.

    Without param offload, all partitions are defragmented into one contiguous
    device buffer, and each sub-group's flat buffer is a ``narrow()`` view of it.

    With param offload, each param group gets a flat CPU buffer and sub-groups
    are carved from it in order. When params live in NVMe and the CPU buffer
    has no room, the sub-group's flat buffer is None instead.

    If any sub-group ended up without a flat buffer, swap space sized for the
    largest sub-group is reserved on the NVMe swapper.
    """
    dist.barrier()
    param_groups: Tuple[List[List[Parameter]], ...] = tuple(
        self._create_fp16_sub_groups(param_group["params"]) for param_group in fp16_param_groups)
    # bookkeeping related to param groups
    for param_group_idx, param_group in enumerate(param_groups):
        for sub_group in param_group:
            sub_group_idx = len(self.fp16_groups)
            # record sub group and partitions
            self.fp16_groups.append(sub_group)
            self.fp16_partitioned_groups.append([param.ds_tensor for param in sub_group])
            # record sub group -> group mapping
            self.sub_group_to_group_id[sub_group_idx] = param_group_idx
            # record total elements of parameter partitions in sub group
            self.fp16_partitioned_groups_flat_numel.append(sum(p.partition_numel() for p in sub_group))
            # record ds_ids of parameter partitions in sub group
            self.fp16_partitioned_groups_flat_id.append([p.ds_id for p in sub_group])
            # record padding required to align group to world size (only applies to last rank)
            rank_requires_padding = dist.get_rank(
                self.dp_process_group) == dist.get_world_size(self.dp_process_group) - 1
            self.groups_padding.append([p.padding_size() if rank_requires_padding else 0 for p in sub_group])
    # move parameters to flattened buffer
    if not self.offload_param:  # partitioned params remain in GPU during training
        # move parameter partitions into a single contiguous flat buffer
        parameter_partitions: List[Tensor] = []
        for sub_group in self.fp16_groups:
            for param in sub_group:
                parameter_partitions.append(param.ds_tensor)
        device_buffer = __class__.defragment(parameter_partitions)
        # setup flat buffers per subgroup, these are each just sections of the
        # contiguous flat buffer for all parameters that we created earlier
        offset = 0
        for sub_group in self.fp16_groups:
            sub_group_numel = sum(param.partition_numel() for param in sub_group)
            self.fp16_partitioned_groups_flat.append(device_buffer.narrow(0, offset, sub_group_numel))
            offset += sub_group_numel
    else:  # partitioned params offloaded to CPU when not in use
        # create a flat CPU memory allocation for each param group
        self._create_param_groups_fp16_flat_cpu_memory()
        for param_group_idx, param_group in enumerate(param_groups):
            flat_offset = 0
            for i, sub_group in enumerate(param_group):
                total_elements = sum(p.partition_numel() for p in sub_group)
                print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
                #Flat buffer may not be available for parameters that reside in NVME
                if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[
                        param_group_idx].numel():
                    fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[param_group_idx].narrow(
                        0, flat_offset, total_elements)
                    print_rank_0(
                        f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}",
                        force=False)
                elif self.params_in_nvme_and_cpu:
                    fp16_partitioned_group_flat = None
                    print_rank_0(f"No flat buffer for sub group {i} of {total_elements} elements", force=False)
                else:
                    # NOTE(review): unreachable given the two conditions above; the message also
                    # has a typo ("see you"), left as-is because it is a runtime string.
                    assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs"
                self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat)
                flat_offset += total_elements
                self._move_to_flat_buffer(sub_group,
                                          fp16_partitioned_group_flat,
                                          avoid_copy=not self.offload_param)
    # if necessary, create a pinned memory buffer to be used for swapping out
    # params to NVME after optimizer step
    should_create_fp16_flat_reuse_buffer = any(flattened_partition_group is None
                                               for flattened_partition_group in self.fp16_partitioned_groups_flat)
    if should_create_fp16_flat_reuse_buffer:
        # NOTE(review): the largest sub-group is chosen by partition_numel(), but the sizes
        # reserved are ds_numel values — confirm this is intended. Also, if every sub-group
        # has 0 elements, largest_partition_numel stays None and len() below raises TypeError.
        max_partition_numel, largest_partition_numel = 0, None
        for sub_group in self.fp16_groups:
            total_elements = sum(t.partition_numel() for t in sub_group)
            if total_elements > max_partition_numel:
                largest_partition_numel = [t.ds_numel for t in sub_group]
                max_partition_numel = total_elements
        assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty'
        self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space(largest_partition_numel)
def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
    """Copy every fp16 partition of sub-group *sub_group_id* into *flat_buffer*, in order.

    A partition whose status is NOT_AVAILABLE is handled in three steps:
      1. swapped in synchronously through the parameter's ``nvme_swapper``;
      2. copied into its slice;
      3. released again.

    *flat_buffer* must have exactly as many elements as the sub-group's
    partitions (asserted).
    """
    partitions = self.fp16_partitioned_groups[sub_group_id]
    expected_numel = sum(part.ds_numel for part in partitions)
    assert (flat_buffer.numel() == expected_numel)

    cursor = 0
    for param, part in zip(self.fp16_groups[sub_group_id], partitions):
        slot = flat_buffer.narrow(0, cursor, part.ds_numel)
        needs_swap_in = part.status == PartitionedParamStatus.NOT_AVAILABLE
        if needs_swap_in:
            print_rank_0(
                f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.partition_numel()}"
            )
            param.nvme_swapper.swap_in([param], async_op=False)
        slot.data.copy_(part.data)
        if needs_swap_in:
            param.nvme_swapper.remove_partition_and_release_buffers([param])
            print_rank_0(f"Swapping in {param.ds_id} done")
        cursor += part.ds_numel
def _create_next_swappable_fp32_groups(self):
    """Record, for each fp32 sub-group, the next swappable fp32 partition after it.

    Sub-groups are visited from last to first while remembering the most
    recently seen swappable partition. After the final in-place reverse,
    entry ``i`` of ``self.next_swappable_fp32_partitioned_groups`` is the flat
    fp32 partition of the nearest sub-group with index greater than ``i`` for
    which ``self._swappable_optimizer_subgroup`` is true, or None if no such
    sub-group exists.

    Entries are appended to the existing list, which is then reversed in
    place, so the list is presumably empty on entry (TODO confirm).
    """
    targets = self.next_swappable_fp32_partitioned_groups
    upcoming = None
    for idx in reversed(range(len(self.fp32_partitioned_groups_flat))):
        targets.append(upcoming)
        if self._swappable_optimizer_subgroup(idx):
            upcoming = self.fp32_partitioned_groups_flat[idx]
    targets.reverse()
def _get_sub_group_partitions(self, sub_group_id):
    """Describe each fp16 partition of a sub-group as a ``(partition, numel, swap_path)`` tuple.

    For partitions whose status is ``PartitionedParamStatus.NOT_AVAILABLE``:
    ``numel`` is ``param.partition_numel()`` and ``swap_path`` is the path
    returned by ``param.nvme_swapper.get_path(param, True)``.

    For every other partition: ``numel`` is ``partitioned_param.ds_numel`` and
    ``swap_path`` is None.

    Returns:
        list of tuples, one per parameter of the sub-group, in order.
    """

    def describe(param, partitioned_param):
        if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
            swap_path = param.nvme_swapper.get_path(param, True)
            return (partitioned_param, param.partition_numel(), swap_path)
        return (partitioned_param, partitioned_param.ds_numel, None)

    pairs = zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id])
    return [describe(param, partitioned_param) for param, partitioned_param in pairs]
def _create_fp32_partitions(self):
    """Create the fp32 master-weight partition for every flat fp16 sub-group.

    For each sub-group ``i`` of ``self.fp16_partitioned_groups_flat``, exactly
    one tensor is appended to ``self.fp32_partitioned_groups_flat``. It is then
    accessed as ``self.fp32_partitioned_groups_flat[i]``, so that list is
    presumably empty on entry (TODO confirm). Each such tensor gets
    ``requires_grad = True`` and a ``ds_id`` of the form
    ``"<first fp16 flat id>_<last fp16 flat id>"``.

    Sub-groups for which ``self._swappable_optimizer_subgroup(i)`` is true get
    an empty placeholder tensor. Its contents are initialized through
    ``self.optimizer_swapper``:
      - when ``self.params_in_nvme_and_cpu`` is set and the flat fp16 tensor
        is None, data is read from the fp16 params. With
        ``offload_optimizer_fast_init`` this happens in one batched
        ``initialize_from_swapped_fp16_params`` call after the loop; otherwise
        it happens per sub-group through an unpinned fp32 buffer;
      - otherwise, the flat fp16 tensors are passed in one batched
        ``initialize_parameters`` call after the loop.

    All other sub-groups get an in-memory fp32 tensor:
      - gathered by ``_swap_in_sub_group_to_flat_buffer`` when
        ``params_in_nvme_and_cpu`` is set and the flat fp16 tensor is None;
      - otherwise a float clone of the flat fp16 tensor, placed on
        ``self.subgroup_to_device[i]`` (optimizer offload) or ``self.device``.

    With ``self.offload_optimizer``, ``self.subgroup_to_device`` is (re)built
    first. The first ``int(self.partial_offload * n)`` of the ``n`` sub-groups
    map to ``'cpu'``; the rest map to ``get_accelerator()._name``.

    Counts and sizes per category are logged with ``print_rank_0``. Finally,
    every param group of ``self.optimizer`` is emptied.
    """
    cpu_memory_usage = 0
    cpu_memory_sub_groups = 0
    nvme_memory_usage = 0
    num_swappable_partitions = 0
    num_swap_from_nvme_partitions = 0
    num_swap_from_cpu_partitions = 0
    swap_from_nvme_memory_usage = 0
    swap_from_cpu_memory_usage = 0
    GIGA_BYTES = (1024**3)  # divisor used only to report sizes in the log messages below
    # Batched initialization inputs for swappable sub-groups whose fp16 data is resident.
    swappable_fp32_tensors = []
    swappable_fp16_src_tensors = []
    # Batched (fast-init) inputs for swappable sub-groups whose fp16 flat tensor is None.
    nvme_fp16_partitions_info = []
    nvme_fp16_num_elems = []
    nvme_fp32_dest_tensors = []
    # Bytes per fp32 element; used only for the memory-usage estimates.
    fp32_element_size = torch.tensor([], dtype=torch.float32).element_size()
    # Optimizer offload: the first int(partial_offload * n) sub-groups are mapped to 'cpu',
    # the remaining ones to the accelerator device name.
    if self.offload_optimizer:
        self.subgroup_to_device = {}
        sub_group_size = len(self.fp16_partitioned_groups_flat)
        # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n")
        for i in range(sub_group_size):
            if i < int(self.partial_offload * sub_group_size):
                self.subgroup_to_device[i] = 'cpu'
            else:
                self.subgroup_to_device[i] = get_accelerator()._name
    for i, tensor in enumerate(self.fp16_partitioned_groups_flat):
        # `tensor` is the flat fp16 partition of sub-group i. When it is None (and
        # params_in_nvme_and_cpu is set), the fp16 data is not held in this flat
        # buffer and is fetched from the individual params instead.
        num_elements = self.fp16_partitioned_groups_flat_numel[i]
        # a partition of the fp32 master weights that will be updated by this process
        if self._swappable_optimizer_subgroup(i):
            # Empty placeholder; its contents are initialized through optimizer_swapper.
            self.fp32_partitioned_groups_flat.append(torch.Tensor())
            nvme_memory_usage += (fp32_element_size * num_elements)
            num_swappable_partitions += 1
            if self.params_in_nvme_and_cpu and tensor is None:
                num_swap_from_nvme_partitions += 1
                swap_from_nvme_memory_usage += (fp32_element_size * num_elements)
                if self.offload_optimizer_fast_init:
                    # Defer: collected and initialized in one batch after the loop.
                    sub_group_partitions = self._get_sub_group_partitions(i)
                    nvme_fp16_partitions_info.append(sub_group_partitions)
                    nvme_fp16_num_elems.append(num_elements)
                    nvme_fp32_dest_tensors.append(self.fp32_partitioned_groups_flat[i])
                else:
                    # Slow path: gather this sub-group into a temporary fp32 buffer now.
                    unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float)
                    self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
                    self.optimizer_swapper.initialize_parameters(parameters=[self.fp32_partitioned_groups_flat[i]],
                                                                 src_tensors=[unpinned_fp32_buffer])
            else:
                num_swap_from_cpu_partitions += 1
                swap_from_cpu_memory_usage += (fp32_element_size * num_elements)
                # Defer: initialized from the resident flat fp16 tensors after the loop.
                swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i])
                swappable_fp16_src_tensors.append(self.fp16_partitioned_groups_flat[i])
        else:
            cpu_memory_usage += (fp32_element_size * num_elements)
            cpu_memory_sub_groups += 1
            if self.params_in_nvme_and_cpu and tensor is None:
                # NOTE(review): this buffer is allocated on self.device rather than
                # self.subgroup_to_device[i]; confirm this is intended when partial
                # optimizer offload maps the sub-group to a different device.
                unpinned_fp32_buffer = torch.empty(num_elements, device=self.device, dtype=torch.float)
                self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i)
                self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer)
            else:
                # fp32 copy of the resident flat fp16 tensor, on the sub-group's device.
                if self.offload_optimizer:
                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
                        self.subgroup_to_device[i]).clone().float().detach())
                else:
                    self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to(
                        self.device).clone().float().detach())
        self.fp32_partitioned_groups_flat[i].requires_grad = True  # keep this in case internal optimizer uses it
        # Tag the fp32 partition with "<first id>_<last id>" of its fp16 flat group.
        ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0])
        ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1])
        self.fp32_partitioned_groups_flat[i].ds_id = ds_id_begin + '_' + ds_id_end
    # Batched initialization of swappable fp32 partitions from resident fp16 tensors.
    if len(swappable_fp32_tensors) > 0:
        self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors,
                                                     src_tensors=swappable_fp16_src_tensors)
    # Fast-init path: initialize swappable fp32 partitions directly from the fp16 param
    # partitions, using buffers reserved from the fp16 param swapper and released afterwards.
    if len(nvme_fp32_dest_tensors) > 0:
        fp16_pinned_buffers = self.fp16_groups[0][0].nvme_swapper.reserve_available_buffers()
        assert len(fp16_pinned_buffers) > 0
        self.optimizer_swapper.initialize_from_swapped_fp16_params(fp16_partitions_info=nvme_fp16_partitions_info,
                                                                   fp16_num_elems=nvme_fp16_num_elems,
                                                                   fp16_pinned_buffers=fp16_pinned_buffers,
                                                                   fp32_parameters=nvme_fp32_dest_tensors)
        self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers()
    nvme_gigabytes = nvme_memory_usage / GIGA_BYTES
    print_rank_0(f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB',
                 force=False)
    if self.params_in_nvme_and_cpu:
        print_rank_0(
            f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB',
            force=False)
        print_rank_0(
            f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB',
            force=False)
    cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES
    print_rank_0(f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB',
                 force=False)
    # Clear for on-the-fly population before the optimizer step: each sub-group's
    # fp32 partition is installed into its param group only while it is being stepped.
    for param_group in self.optimizer.param_groups:
        param_group['params'] = []
def _create_fp16_sub_groups(self, params_group):
    """Split a parameter group into consecutive sub-groups bounded by ``self.sub_group_size``.

    Parameters are taken in order and their ``partition_numel()`` is
    accumulated into the current sub-group. A sub-group is closed as soon as
    its element count reaches ``self.sub_group_size``, or when the last
    parameter of ``params_group`` has been added.

    Returns:
        list of lists of parameters. If ``self.sub_group_size`` is None or is
        not smaller than the group's total partition size, the result is
        ``[params_group]`` (the original list object, not a copy).
    """
    total_numel = sum(p.partition_numel() for p in params_group)
    limit = self.sub_group_size
    if limit is None or limit >= total_numel:
        return [params_group]

    # Identity of the final element; the loop closes the trailing sub-group on it.
    last_param = params_group[-1] if params_group else None
    result = []
    current = []
    running = 0
    for p in params_group:
        current.append(p)
        running += p.partition_numel()
        if running >= limit or p is last_param:
            result.append(current)
            current, running = [], 0
    return result
def _release_ipg_buffers(self):
    """Drop the reference held in ``self.ipg_buffer`` when contiguous gradients are enabled.

    Does nothing when ``self.contiguous_gradients`` is falsy.
    """
    if not self.contiguous_gradients:
        return
    self.ipg_buffer = None
def _optimizer_step(self, sub_group_id):
    """Run one optimizer step on the fp32 partition of a single sub-group.

    The sub-group's flat fp32 partition is installed as the only parameter of
    its param group (``self.sub_group_to_group_id[sub_group_id]``), the
    optimizer is stepped, and the param group is emptied again. The reset
    runs in a ``finally`` block, so a failing ``step()`` does not leave the
    partition referenced by the optimizer.

    Optimizer selection: with ``self.offload_optimizer`` set, sub-groups
    mapped to ``'cpu'`` in ``self.subgroup_to_device`` use ``self.optimizer``
    and all other sub-groups use ``self.backup_optimizer``. Without optimizer
    offload, ``self.optimizer`` is always used.

    Args:
        sub_group_id: index of the sub-group to update.
    """
    param_group_id = self.sub_group_to_group_id[sub_group_id]
    fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]

    optimizer = self.optimizer
    if self.offload_optimizer and self.subgroup_to_device[sub_group_id] != 'cpu':
        # Partial offload: sub-groups kept off the CPU are stepped by the backup optimizer.
        optimizer = self.backup_optimizer

    optimizer.param_groups[param_group_id]['params'] = [fp32_param]
    try:
        optimizer.step()
    finally:
        # Keep param groups empty between steps (they are populated on the fly).
        optimizer.param_groups[param_group_id]['params'] = []
def _swappable_optimizer_subgroup(self, sub_group_id):
    """Tell whether sub-group ``sub_group_id`` is handled by the optimizer swapper.

    Returns False when ``self.swap_optimizer`` is falsy. Otherwise returns
    whatever ``self.optimizer_swapper.swappable_tensor`` reports for the
    sub-group's flat fp16 element count.
    """
    if self.swap_optimizer:
        numel = self.fp16_partitioned_groups_flat_numel[sub_group_id]
        return self.optimizer_swapper.swappable_tensor(None, numel=numel)
    return False
def _partitioned_params_swap_out(self, i):
    """Copy the flat fp32 partition of sub-group ``i`` back into its fp16 partitions.

    Consecutive slices of ``self.fp32_partitioned_groups_flat[i]`` (one per
    parameter, sized by ``ds_numel``) are handled as follows:
      - slices for partitions with status ``PartitionedParamStatus.AVAILABLE``
        are copied into them directly;
      - slices for all other partitions are collected and passed in a single
        ``swap_out_partitioned_params`` call on the first such parameter's
        ``nvme_swapper``.

    Raises:
        AssertionError: if the fp32 partition for sub-group ``i`` is None.
    """
    fp32_param = self.fp32_partitioned_groups_flat[i]
    assert fp32_param is not None, \
        f'fp32 parameters of sub_group {i} is None'

    pending_fp16 = []
    pending_fp32 = []
    start = 0
    for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]):
        chunk = fp32_param.narrow(0, start, partitioned_param.ds_numel)
        if partitioned_param.status == PartitionedParamStatus.AVAILABLE:
            partitioned_param.data.copy_(chunk.data)
        else:
            # Not resident: defer to the batched NVMe swap-out below.
            pending_fp32.append(chunk)
            pending_fp16.append(param)
        start += partitioned_param.ds_numel

    if pending_fp16:
        pending_fp16[0].nvme_swapper.swap_out_partitioned_params(dst_fp16_params=pending_fp16,
                                                                 src_fp32_params=pending_fp32)
def initialize_optimizer_states(self):
num_subgroups = len(self.fp16_groups)
largest_numel = max([sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups])
gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype
gradient_buffer = torch.zeros(int(largest_numel), dtype=gradient_dtype, device=self.device)
timer_names = set()
# State initialization for the Adagrad optimizer occurs at construction as opposed to other optimizers
# which do lazy initialization of the state at the first call to step.
is_adagrad = isinstance(self.optimizer, torch.optim.Adagrad)
if self.swap_optimizer:
self.optimizer_swapper.init_timers()
timer_names.add(INIT_OPTIMIZER_TIMER)
self.timers(INIT_OPTIMIZER_TIMER).start()
for i, group in enumerate(self.fp16_groups):
swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i)
swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None