From 53182531ed7e33d980ded64fad2fb209d9919672 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Sat, 3 Sep 2022 12:16:07 -0400 Subject: [PATCH] Refactor universal checkpointing and tensor fragments (#2253) * Refactor universal checkpointing and tensor fragments * Formatting --- deepspeed/checkpoint/__init__.py | 2 + deepspeed/checkpoint/constants.py | 3 + deepspeed/checkpoint/universal_checkpoint.py | 110 ++++++++ deepspeed/runtime/bf16_optimizer.py | 251 ++----------------- deepspeed/utils/__init__.py | 2 + deepspeed/utils/mixed_precision_linkage.py | 45 ++++ deepspeed/utils/tensor_fragment.py | 105 ++++++++ 7 files changed, 284 insertions(+), 234 deletions(-) create mode 100644 deepspeed/checkpoint/universal_checkpoint.py create mode 100644 deepspeed/utils/mixed_precision_linkage.py create mode 100644 deepspeed/utils/tensor_fragment.py diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py index edb424e9dfa8..407a9b50a7bb 100644 --- a/deepspeed/checkpoint/__init__.py +++ b/deepspeed/checkpoint/__init__.py @@ -11,3 +11,5 @@ from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor) from .zero_checkpoint import ZeROCheckpoint + +from .universal_checkpoint import enable_universal_checkpoint diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py index dc79df643af2..b46502ceae36 100644 --- a/deepspeed/checkpoint/constants.py +++ b/deepspeed/checkpoint/constants.py @@ -21,8 +21,11 @@ ######################################### # Module checkpoint keys ######################################### +PARAM = 'param' PARAM_SHAPES = 'param_shapes' BUFFER_NAMES = 'buffer_names' +VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor' +CAT_DIM = "cat_dim" ######################################### # Checkpoint naming constants diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py new file mode 100644 index 000000000000..f791dec6afa4 --- /dev/null +++ b/deepspeed/checkpoint/universal_checkpoint.py @@ -0,0 +1,110 @@ +""" +Copyright 2022 The Microsoft DeepSpeed Team +""" +import os +import torch +import types + +from .constants import (FP32_WEIGHT_KEY, + PARAM, + VOCAB_DIVISIBILITY_PADDING_TENSOR, + CAT_DIM) + + +def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): + hp_mapping = self._hp_mapping + optim_state_keys = hp_mapping.get_optim_state_keys() + hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys + checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} + + for file in checkpoint_files.values(): + assert os.path.isfile(file), f'{file} is not a valid file' + + for key in hp_keys: + ckpt_file = checkpoint_files[key] + ckpt_dict = torch.load(ckpt_file) + full_hp_param = ckpt_dict[PARAM] + + # need to deal with slices that were averaged. + # the opposite of averaging here becomes an exact copy of the first slice + # I thought of 2 ways: + # implementation a. find a way for a client to pass a dict with patterns + # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): + # tp_rank = 0 + # tp_world_size = 1 + # the other approach is to assume that the saved data is correct and if full_hp_param.shape == + # self.shape that means we automatically copy? + # implementation b. + # this version requires no additional data passed from the client + # if the shapes already match it must be slices that were averaged - so we just hack around those + if full_hp_param.shape == self.shape: + tp_rank = 0 + tp_world_size = 1 + + # special case for word_embeddings weights which get padded differently depending on TP degree. + # the converter to universal currently strips the original padding completely so the saved + # weight is padding-free and we just need to add new padding depending on the target TP + # degree + vocab_divisibility_padding_tensor = ckpt_dict.get( + VOCAB_DIVISIBILITY_PADDING_TENSOR, + None) + if vocab_divisibility_padding_tensor is not None: + # In the absence of data passed from the user wrt new padded vocab specific to tp degree + # we can again derive that data by reverse engineering the target shapes like so: + padded_target_vocab_size = self.shape[0] * tp_world_size + if padded_target_vocab_size > full_hp_param.shape[0]: + # Need to expand + padding_tensor = vocab_divisibility_padding_tensor.expand( + padded_target_vocab_size - full_hp_param.shape[0]) + # Implement the following concat in efficient way using pad + #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) + full_hp_param = torch.nn.functional.pad(full_hp_param, + (0, + 0, + 0, + padding_tensor.shape[0]), + "constant", + 0) + full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor + else: + # Need to shrink or keep the same + full_hp_param = full_hp_param[:padded_target_vocab_size, :] + + full_param_numel = full_hp_param.numel() + tp_slice_numel = self.numel() + # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: + # print_rank_0(f'{full_hp_param[:10]=}', force=True) + + + assert full_param_numel == tp_world_size * tp_slice_numel, \ + f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' + dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment( + key) + + # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") + # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") + + # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse + chunk_dim = ckpt_dict.get(CAT_DIM, 0) + + # this performs the opposite of cat when merging TP slices + tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] + tp_hp_slice = tp_hp_slice.flatten() + + lp_frag_address = hp_mapping.lp_fragment_address + tp_hp_fragment = tp_hp_slice.narrow(0, + lp_frag_address.start, + lp_frag_address.numel) + assert dst_tensor.numel() == lp_frag_address.numel, \ + f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}' + + # print(f"{key} SHAPE: {tp_hp_slice.shape=}") + # print(f"{key} SHAPE: {dst_tensor.shape=}") + # print(f"{key} SHAPE: {tp_hp_fragment.shape=}") + dst_tensor.data.copy_(tp_hp_fragment.data) + + +def enable_universal_checkpoint(param_list): + for param in param_list: + param.load_hp_checkpoint_state = types.MethodType(load_hp_checkpoint_state, + param) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index a9988e2c498f..40b5b769bad1 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -21,162 +21,15 @@ is_model_parallel_parameter, see_memory_usage) +from deepspeed.utils import link_hp_params +from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, - PARAM_SLICE_MAPPINGS, - FP32_WEIGHT_KEY) - -import types - -from dataclasses import dataclass - - -@dataclass -class fragment_address: - numel: int - start: int - - -@dataclass -class tensor_fragment: - lp_fragment: torch.Tensor - lp_fragment_address: fragment_address - hp_fragment: torch.Tensor - hp_fragment_address: fragment_address - optim_fragment: {} - - def update_hp(self): - self.hp_fragment.data.copy_(self.lp_fragment.data) - - def update_lp(self): - self.lp_fragment.data.copy_(self.hp_fragment.data) - - def get_optim_state_fragment(self, key): - if key in self.optim_fragment: - return self.optim_fragment[key] - else: - raise ValueError(f'{key} not found in optimizer state fragment') - - def get_hp_fragment_address(self): - return self.hp_fragment_address - - def get_optim_state_keys(self): - return list(self.optim_fragment.keys()) - - -def get_full_hp_param(self, optim_state_key=None): - reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() - if self._hp_mapping is not None: - lp_frag_address = self._hp_mapping.lp_fragment_address - reduce_fragment = torch.narrow(reduce_buffer, - 0, - lp_frag_address.start, - lp_frag_address.numel) - if optim_state_key is None: - hp_fragment = self._hp_mapping.hp_fragment - else: - hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key) - - reduce_fragment.data.copy_(hp_fragment.data) - dist.all_reduce(reduce_buffer, group=self._dp_group) - return reduce_buffer.reshape_as(self) - - -def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size): - hp_mapping = self._hp_mapping - optim_state_keys = hp_mapping.get_optim_state_keys() - hp_keys = [FP32_WEIGHT_KEY] + optim_state_keys - checkpoint_files = {key: os.path.join(folder, f"{key}.pt") for key in hp_keys} - - for file in checkpoint_files.values(): - assert os.path.isfile(file), f'{file} is not a valid file' - - for key in hp_keys: - ckpt_file = checkpoint_files[key] - ckpt_dict = torch.load(ckpt_file) - full_hp_param = ckpt_dict['param'] - - # need to deal with slices that were averaged. - # the opposite of averaging here becomes an exact copy of the first slice - # I thought of 2 ways: - # implementation a. find a way for a client to pass a dict with patterns - # if any(re.search(pattern, folder) for pattern in WEIGHTS_TO_AVERAGE_PATTERNS): - # tp_rank = 0 - # tp_world_size = 1 - # the other approach is to assume that the saved data is correct and if full_hp_param.shape == - # self.shape that means we automatically copy? - # implementation b. - # this version requires no additional data passed from the client - # if the shapes already match it must be slices that were averaged - so we just hack around those - if full_hp_param.shape == self.shape: - tp_rank = 0 - tp_world_size = 1 - - # special case for word_embeddings weights which get padded differently depending on TP degree. - # the converter to universal currently strips the original padding completely so the saved - # weight is padding-free and we just need to add new padding depending on the target TP - # degree - vocab_divisibility_padding_tensor = ckpt_dict.get( - 'vocab_divisibility_padding_tensor', - None) - if vocab_divisibility_padding_tensor is not None: - # In the absence of data passed from the user wrt new padded vocab specific to tp degree - # we can again derive that data by reverse engineering the target shapes like so: - padded_target_vocab_size = self.shape[0] * tp_world_size - if padded_target_vocab_size > full_hp_param.shape[0]: - # Need to expand - padding_tensor = vocab_divisibility_padding_tensor.expand( - padded_target_vocab_size - full_hp_param.shape[0]) - # Implement the following concat in efficient way using pad - #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0) - full_hp_param = torch.nn.functional.pad(full_hp_param, - (0, - 0, - 0, - padding_tensor.shape[0]), - "constant", - 0) - full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor - else: - # Need to shrink or keep the same - full_hp_param = full_hp_param[:padded_target_vocab_size, :] - - full_param_numel = full_hp_param.numel() - tp_slice_numel = self.numel() - # if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder: - # print_rank_0(f'{full_hp_param[:10]=}', force=True) - - - assert full_param_numel == tp_world_size * tp_slice_numel, \ - f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}' - dst_tensor = hp_mapping.hp_fragment if key == FP32_WEIGHT_KEY else hp_mapping.get_optim_state_fragment( - key) - - # print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}") - # print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}") - - # since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse - chunk_dim = ckpt_dict.get('cat_dim', 0) - - # this performs the opposite of cat when merging TP slices - tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank] - tp_hp_slice = tp_hp_slice.flatten() - - lp_frag_address = hp_mapping.lp_fragment_address - tp_hp_fragment = tp_hp_slice.narrow(0, - lp_frag_address.start, - lp_frag_address.numel) - assert dst_tensor.numel() == lp_frag_address.numel, \ - f'Load checkpoint {key} dst_tensor numel {dst_tensor.numel()} != src numel {lp_frag_address.numel}' - - # print(f"{key} SHAPE: {tp_hp_slice.shape=}") - # print(f"{key} SHAPE: {dst_tensor.shape=}") - # print(f"{key} SHAPE: {tp_hp_fragment.shape=}") - dst_tensor.data.copy_(tp_hp_fragment.data) + PARAM_SLICE_MAPPINGS) class BF16_Optimizer(ZeROOptimizer): @@ -327,8 +180,13 @@ def _setup_for_real_optimizer(self): # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() + def _enable_universal_checkpoint(self): + for lp_param_group in self.bf16_groups: + enable_universal_checkpoint(param_list=lp_param_group) + def _create_param_mapping(self): param_mapping = [] for i, _ in enumerate(self.optimizer.param_groups): @@ -344,93 +202,18 @@ def _create_param_mapping(self): def _link_all_hp_params(self): dp_world_size = dist.get_world_size(group=self.dp_process_group) - for i, param_group in enumerate(self.optimizer.param_groups): + for i, _ in enumerate(self.optimizer.param_groups): # Link bf16 and fp32 params in partition partition_id = dist.get_rank(group=self.real_dp_process_group[i]) partition_size = self.bf16_groups_flat[i].numel() // dp_world_size - self._link_hp_params(self.bf16_groups[i], - self.fp32_groups_flat_partition[i], - partition_id * partition_size, - partition_size, - self.real_dp_process_group[i]) - - def _init_lp_to_hp_mapping(self, - lp_param_list, - partition_start, - partition_size, - dp_group): - current_offset = 0 - param_and_offset_list = [] - partition_end = partition_start + partition_size - for lp_param in lp_param_list: - lp_param._hp_mapping = None - lp_param._dp_group = dp_group - lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) - lp_param.load_hp_checkpoint_state = types.MethodType( - load_hp_checkpoint_state, - lp_param) - # lp_param overlaps with partition if both are true - # 1) current_offset < partition_end, - # 2) current_offset + lp_param.numel() >= partition_start - lp_param_end = current_offset + lp_param.numel() - if current_offset < partition_end and lp_param_end > partition_start: - param_and_offset_list.append((lp_param, current_offset)) - current_offset += lp_param.numel() - - return param_and_offset_list - - def _link_hp_params(self, - lp_param_list, - flat_hp_partition, - partition_start, - partition_size, - dp_group): - local_lp_param_and_offset = self._init_lp_to_hp_mapping( - lp_param_list, - partition_start, - partition_size, - dp_group) - - hp_end = partition_start + partition_size - for lp_param, lp_start in local_lp_param_and_offset: - lp_end = lp_param.numel() + lp_start - hp_start = partition_start - - fragment_start = max(lp_start, hp_start) - fragment_end = min(lp_end, hp_end) - # print( - # f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' - # ) - assert fragment_start < fragment_end, \ - f'fragment start {fragment_start} should be < fragment_end {fragment_end}' - - fragment_numel = fragment_end - fragment_start - hp_frag_address = fragment_address(start=fragment_start - hp_start, - numel=fragment_numel) - hp_fragment_tensor = flat_hp_partition.narrow(0, - hp_frag_address.start, - hp_frag_address.numel) - - optim_fragment = { - key: value.narrow(0, - hp_frag_address.start, - hp_frag_address.numel) - for key, - value in self.optimizer.state[flat_hp_partition].items() - if torch.is_tensor(value) and value.dim() > 0 - } - - lp_frag_address = fragment_address(start=fragment_start - lp_start, - numel=fragment_numel) - lp_fragment_tensor = lp_param.flatten().narrow(0, - lp_frag_address.start, - lp_frag_address.numel) - - lp_param._hp_mapping = tensor_fragment(lp_fragment=lp_fragment_tensor, - lp_fragment_address=lp_frag_address, - hp_fragment=hp_fragment_tensor, - hp_fragment_address=hp_frag_address, - optim_fragment=optim_fragment) + flat_hp_partition = self.fp32_groups_flat_partition[i] + link_hp_params( + lp_param_list=self.bf16_groups[i], + flat_hp_partition=flat_hp_partition, + partition_start=partition_id * partition_size, + partition_size=partition_size, + partition_optimizer_state=self.optimizer.state[flat_hp_partition], + dp_group=self.real_dp_process_group[i]) def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 5e05bf46e9b6..6dd805b37844 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -4,4 +4,6 @@ from .init_on_device import OnDevice from .groups import * from .nvtx import instrument_w_nvtx +from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping +from .mixed_precision_linkage import link_hp_params from deepspeed.runtime.dataloader import RepeatingLoader diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py new file mode 100644 index 000000000000..bfd9932b8d7a --- /dev/null +++ b/deepspeed/utils/mixed_precision_linkage.py @@ -0,0 +1,45 @@ +""" +Copyright 2022 The Microsoft DeepSpeed Team +""" +import types +from deepspeed.utils import get_full_hp_param, get_hp_fragment_mapping + + +def link_hp_params(lp_param_list, + flat_hp_partition, + partition_start, + partition_size, + partition_optimizer_state, + dp_group): + local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, + partition_start, + partition_size, + dp_group) + + for lp_param, lp_start in local_lp_param_and_offset: + lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, + lp_start, + flat_hp_partition, + partition_start, + partition_size, + partition_optimizer_state) + + +def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group): + current_offset = 0 + param_and_offset_list = [] + partition_end = partition_start + partition_size + for lp_param in lp_param_list: + lp_param._hp_mapping = None + lp_param._dp_group = dp_group + lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) + + # lp_param overlaps with partition if both are true + # 1) current_offset < partition_end, + # 2) current_offset + lp_param.numel() >= partition_start + lp_param_end = current_offset + lp_param.numel() + if current_offset < partition_end and lp_param_end > partition_start: + param_and_offset_list.append((lp_param, current_offset)) + current_offset += lp_param.numel() + + return param_and_offset_list diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py new file mode 100644 index 000000000000..913b188df9a9 --- /dev/null +++ b/deepspeed/utils/tensor_fragment.py @@ -0,0 +1,105 @@ +""" +Copyright 2022 The Microsoft DeepSpeed Team +""" + +import torch +from dataclasses import dataclass +from deepspeed import comm as dist + + +@dataclass +class fragment_address: + numel: int + start: int + + +@dataclass +class tensor_fragment: + lp_fragment: torch.Tensor + lp_fragment_address: fragment_address + hp_fragment: torch.Tensor + hp_fragment_address: fragment_address + optim_fragment: {} + + def update_hp(self): + self.hp_fragment.data.copy_(self.lp_fragment.data) + + def update_lp(self): + self.lp_fragment.data.copy_(self.hp_fragment.data) + + def get_optim_state_fragment(self, key): + if key in self.optim_fragment: + return self.optim_fragment[key] + else: + raise ValueError(f'{key} not found in optimizer state fragment') + + def get_hp_fragment_address(self): + return self.hp_fragment_address + + def get_optim_state_keys(self): + return list(self.optim_fragment.keys()) + + +def get_full_hp_param(self, optim_state_key=None): + reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() + if self._hp_mapping is not None: + lp_frag_address = self._hp_mapping.lp_fragment_address + reduce_fragment = torch.narrow(reduce_buffer, + 0, + lp_frag_address.start, + lp_frag_address.numel) + if optim_state_key is None: + hp_fragment = self._hp_mapping.hp_fragment + else: + hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key) + + reduce_fragment.data.copy_(hp_fragment.data) + dist.all_reduce(reduce_buffer, group=self._dp_group) + return reduce_buffer.reshape_as(self) + + +def get_hp_fragment_mapping(lp_param, + lp_start, + flat_hp_partition, + partition_start, + partition_size, + optimizer_state_dict): + lp_end = lp_param.numel() + lp_start + hp_start = partition_start + hp_end = partition_start + partition_size + + fragment_start = max(lp_start, hp_start) + fragment_end = min(lp_end, hp_end) + # print( + # f'{self.dp_rank=} {lp_start=} {lp_end-lp_start=} {hp_start=} {hp_end-hp_start=} {fragment_start=} {fragment_end-fragment_start=}' + # ) + assert fragment_start < fragment_end, \ + f'fragment start {fragment_start} should be < fragment_end {fragment_end}' + + fragment_numel = fragment_end - fragment_start + hp_frag_address = fragment_address(start=fragment_start - hp_start, + numel=fragment_numel) + hp_fragment_tensor = flat_hp_partition.narrow(0, + hp_frag_address.start, + hp_frag_address.numel) + + optim_fragment = { + key: value.narrow(0, + hp_frag_address.start, + hp_frag_address.numel) + for key, + value in optimizer_state_dict.items() + if torch.is_tensor(value) and value.dim() > 0 + } + + lp_frag_address = fragment_address(start=fragment_start - lp_start, + numel=fragment_numel) + lp_fragment_tensor = lp_param.flatten().narrow(0, + lp_frag_address.start, + lp_frag_address.numel) + + return tensor_fragment(lp_fragment=lp_fragment_tensor, + lp_fragment_address=lp_frag_address, + hp_fragment=hp_fragment_tensor, + hp_fragment_address=hp_frag_address, + optim_fragment=optim_fragment)