Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2

# Vocabulary padding
VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
VOCAB_TENSOR = 'vocab_tensor'
PADDED_VOCAB_SIZE = 'padded_vocab_size'
ORIGINAL_VOCAB_SIZE = 'original_vocab_size'

Expand Down
60 changes: 41 additions & 19 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
PARAM_SHAPES,
PARAM,
CAT_DIM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
ORIGINAL_VOCAB_SIZE,
VOCAB_TENSOR,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
Expand Down Expand Up @@ -55,6 +54,10 @@ def parse_arguments():
parser.add_argument('--keep_temp_folder',
action='store_true',
help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
parser.add_argument('--no_strict',
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
args = parser.parse_args()
print(f'args = {args}')
return args
Expand Down Expand Up @@ -149,15 +152,8 @@ def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
return slices


def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor):
original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE)
if padded_vocab_tensor.shape[0] > original_vocab_size:
return padded_vocab_tensor[-1]
else:
return torch.zeros(padded_vocab_tensor.shape[1])


def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):

name, shape = name_and_shape
slice_base_path = os.path.join(slice_dir, name)
param_base_path = os.path.join(dir, name)
Expand All @@ -167,39 +163,53 @@ def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
unmatched_patterns = set(replicated_parameters + parameters_to_average + parameters_with_row_parallelism +
vocabulary_parameters)

def get_matched_pattern(patterns_, name_):
matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
assert len(matched_) <= 1, f'Got more than one matching patterns={matched_} for {name_}'
if matched_:
pattern_ = matched_[0]
unmatched_patterns.discard(pattern_)
return pattern_
return None

for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")

#print(f"Expected shape: {shape}")
#print(f"Fragment sizes:", list(frag.shape for frag in slices))
ckpt_dict = {}
if any(re.match(pattern, name) for pattern in replicated_parameters):
if get_matched_pattern(replicated_parameters, name):
if len(slices) > 1:
assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
param = slices[0]
# print(f'replicate {name} using first slice')
elif any(re.match(pattern, name) for pattern in parameters_to_average):
elif get_matched_pattern(parameters_to_average, name):
param = sum(slices) / len(slices)
# print(f'merge {name} using average')
else:
cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
cat_dim = 1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
# print(f"merge {name} with CAT DIM: {cat_dim}")
param = torch.cat(slices, dim=cat_dim)
ckpt_dict[CAT_DIM] = cat_dim

if any(re.match(pattern, name) for pattern in vocabulary_parameters):
if get_matched_pattern(vocabulary_parameters, name):
#print(f"Before {param.shape=}")
# strip padding
#param = _strip_vocab_padding(ds_checkpoint, param)
ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(
universal_checkpoint_info, param)
original_vocab_size = universal_checkpoint_info['original_vocab_size']
param = param[:original_vocab_size, :]
ckpt_dict[VOCAB_TENSOR] = True
#print(f"After {param.shape=}")

#print(f"Final shape: {param.shape}")
ckpt_dict[PARAM] = param
_save_checkpoint(final_path, ckpt_dict)

return unmatched_patterns


def _get_chunks(l, n):
for i in range(0, len(l), n):
Expand All @@ -208,10 +218,13 @@ def _get_chunks(l, n):

def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
pool.map(do_work, batch)
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
return results


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
Expand All @@ -232,7 +245,16 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
_do_parallel_work(do_work, work_chunks, args.num_merge_workers)
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)

# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
sets = [set(lst) for lst in unmatched_patterns_lists]
unmatched_patterns = list(set.intersection(*sets))
if args.strict:
assert not unmatched_patterns, f'Unused patterns={unmatched_patterns} while merging tp slices'
elif unmatched_patterns:
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')


def _save_optimizer_state(args, ds_checkpoint):
Expand Down
15 changes: 5 additions & 10 deletions deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import torch
import types
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_DIVISIBILITY_PADDING_TENSOR, CAT_DIM)
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM)


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
Expand Down Expand Up @@ -43,21 +43,16 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor = ckpt_dict.get(VOCAB_DIVISIBILITY_PADDING_TENSOR, None)
if vocab_divisibility_padding_tensor is not None:
is_vocab_tensor = ckpt_dict.get(VOCAB_TENSOR, False)
if is_vocab_tensor:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
padded_target_vocab_size = self.shape[0] * tp_world_size
assert padded_target_vocab_size >= full_hp_param.shape[0], \
f'Vocab tensor padded size {padded_target_vocab_size} < loaded universal size {full_hp_param.shape[0]}'
if padded_target_vocab_size > full_hp_param.shape[0]:
# Need to expand
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)
full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
else:
# Need to shrink or keep the same
full_hp_param = full_hp_param[:padded_target_vocab_size, :]

full_param_numel = full_hp_param.numel()
tp_slice_numel = self.numel()
Expand Down
8 changes: 8 additions & 0 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,14 @@ def model_parallel_cuda_manual_seed(seed):
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)


def model_parallel_reconfigure_tp_seed(seed):
global mpu
tp_rank = bwc_tensor_model_parallel_rank(mpu)
model_parallel_seed = seed + 2718 + tp_rank
with _CUDA_RNG_STATE_TRACKER.fork():
get_accelerator().manual_seed(model_parallel_seed)


def get_partition_start(item):
global mp_rank, mp_size, mp_group
size = item.numel()
Expand Down
2 changes: 0 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def __init__(self,
self.fp32_groups_gradient_flat_partition = []
self.fp32_groups_has_gradients = []

self.step_count = 0
self.group_paddings = []

if self.using_real_optimizer:
Expand Down Expand Up @@ -252,7 +251,6 @@ def step(self, closure=None):
self.update_lp_params()

self.clear_hp_grads()
self.step_count += 1

def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
"""Perform a backward pass and copy the low-precision gradients to the
Expand Down
20 changes: 20 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2726,6 +2726,8 @@ def load_checkpoint(self,

if self.load_universal_checkpoint():
self.optimizer.update_lp_params()
if load_zero_checkpoint:
self.update_optimizer_step(step=client_states['iteration'] + 1)

return load_path, client_states

Expand Down Expand Up @@ -2903,6 +2905,24 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}")
return True

def update_optimizer_step(self, step):

def set_step(d):
if isinstance(d['step'], torch.Tensor):
d['step'] = torch.tensor(step, dtype=d['step'].dtype, device=d['step'].device)
else:
d['step'] = step

optimizer = self.optimizer
base_optimizer = optimizer.optimizer
state = base_optimizer.state
for group in optimizer.param_groups:
if 'step' in group:
set_step(group)
for p in group['params']:
if p in state and len(state[p]) > 0 and 'step' in state[p]:
set_step(state[p])

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode):
zero_ckpt_names = []
for dp_rank in range(dp_world_size):
Expand Down
3 changes: 0 additions & 3 deletions deepspeed/runtime/fp16/fused_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def __init__(self,

self.clip_grad = clip_grad
self.norm_type = 2
self.step_count = 0

if required_torch_version(max_version=0.4):
self.clip_grad_norm = torch.nn.utils.clip_grad_norm
Expand Down Expand Up @@ -289,8 +288,6 @@ def step(self, closure=None):

self.timers.log(STEP_TIMERS)

self.step_count += 1

return self.overflow

def _get_norm_with_moe_layers(self, all_groups_norm):
Expand Down