Support universal checkpoint for GPTModel (deepspeedai#361)
Save the universal checkpoint patterns required for GPTModel to its checkpoints.

Additionally, unify the logic of universal checkpoint info for both GPTModel
and GPTModelPipe under a new class: UniversalCheckpointInfo.

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
mosheisland authored Mar 10, 2024
1 parent a9856ce commit df0e2e4
Showing 1 changed file with 115 additions and 111 deletions.
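
As a quick orientation before the diff: a minimal sketch of how the new helper is consumed and what it returns. This is illustrative only, not part of the commit; it assumes Megatron-DeepSpeed's global arguments have already been initialized so that get_args() works, and that the import path matches the file changed below.

    # Illustrative sketch, not part of the commit.
    # Assumes Megatron-DeepSpeed has been initialized so get_args() is populated.
    from megatron.model.gpt_model import UniversalCheckpointInfo  # path per the diff below

    gpt_info = UniversalCheckpointInfo(using_model_pipe=False).get()   # used by GPTModel
    pipe_info = UniversalCheckpointInfo(using_model_pipe=True).get()   # used by GPTModelPipe

    # Each call returns a dict mapping DeepSpeed universal-checkpoint keys (e.g.
    # VOCABULARY_PARAMETER_PATTERNS, TP_REPLICATED_PARAMETER_PATTERNS) to lists of
    # regex patterns over parameter names, which guide how each parameter is handled
    # when a distributed checkpoint is converted to a universal one.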
226 changes: 115 additions & 111 deletions megatron/model/gpt_model.py
@@ -62,6 +62,119 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         return loss
 
 
+class UniversalCheckpointInfo:
+    def __init__(self, using_model_pipe: bool):
+        self.using_model_pipe = using_model_pipe
+        self.args = get_args()
+        self.info = self._build_universal_checkpoint_info()
+
+    def get(self):
+        return self.info
+
+    def _build_universal_checkpoint_info(self):
+        info = dict()
+        if DS_UNIVERSAL_CHECKPOINT_INFO:
+            # Vocabulary parameters (embeddings) that require special handling due to padding.
+            info[VOCABULARY_PARAMETER_PATTERNS] = self._get_vocab_param_patterns()
+
+            if self.using_model_pipe:
+                # Replicated (shared) parameters on the pipeline dimension
+                info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = self._get_pp_replicated_param_patterns()
+
+            if self.args.tensor_model_parallel_size > 1:
+                # Parameter slices that should be averaged not concatenated.
+                info[TP_REPLICATED_PARAMETER_PATTERNS] = self._get_tp_replicated_param_patterns()
+
+                # Parameter that are sliced on the row dimension
+                info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = self._get_row_parallel_param_patterns()
+
+            # SWIGLU parameters are first sliced on dim=0 to tp slices
+            # Then, each tp slice is chunked into 2 to create the linear layers L1, L2 used for silu(L1(x)) * L2(x))
+            info[PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0] = self._get_swiglu_col_parallel_param_patterns()
+        return info
+
+    def _get_vocab_param_patterns(self):
+        if self.using_model_pipe:
+            if self.args.untie_embeddings_and_output_weights:
+                patterns = [
+                    r"\d+.word_embeddings.weight",
+                    r"\d+.lm_head.weight"
+                ]
+            else:
+                patterns = [
+                    r"tied_modules.embed.word_embeddings.weight"
+                ]
+        else:
+            patterns = [
+                "language_model.embedding.word_embeddings.weight"
+            ]
+            if self.args.untie_embeddings_and_output_weights:
+                patterns.append("language_model.output_layer.weight")
+        return patterns
+
+    def _get_pp_replicated_param_patterns(self):
+        if self.args.untie_embeddings_and_output_weights:
+            return []
+        patterns = self._get_vocab_param_patterns()
+        if self.args.add_position_embedding:
+            patterns.append(r"tied_modules.embed.position_embeddings.weight")
+        return patterns
+
+    def _layers_prefix(self):
+        return "" if self.using_model_pipe else "language_model.encoder.layers."
+
+    def _get_tp_replicated_param_patterns(self):
+        layers_prefix = self._layers_prefix()
+        patterns = [
+            layers_prefix + r"\d+.input_layernorm.weight",
+            layers_prefix + r"\d+.post_attention_layernorm.weight",
+        ]
+        # Add final normalization layer
+        final_norm_w_pattern = r"\d+.weight" if self.using_model_pipe \
+            else "language_model.encoder.final_layernorm.weight"
+        patterns.append(final_norm_w_pattern)
+        if self.args.normalization == 'layernorm':
+            final_norm_b_pattern = r"\d+.bias" if self.using_model_pipe \
+                else "language_model.encoder.final_layernorm.bias"
+            patterns.append(final_norm_b_pattern)
+        # add Positional Embedding
+        if self.args.add_position_embedding:
+            pos_emb_pattern = "tied_modules.embed.position_embeddings.weight" if self.using_model_pipe \
+                else "language_model.embedding.position_embeddings.weight"
+            patterns.append(pos_emb_pattern)
+        # add Linear bias
+        if self.args.add_bias_linear:
+            patterns.extend([
+                layers_prefix + r"\d+.self_attention.dense.bias",
+                layers_prefix + r"\d+.mlp.dense_4h_to_h.bias",
+            ])
+        # add LN bias
+        if self.args.normalization == 'layernorm':
+            patterns.extend([
+                layers_prefix + r"\d+.input_layernorm.bias",
+                layers_prefix + r"\d+.post_attention_layernorm.bias",
+            ])
+        return patterns
+
+    def _get_row_parallel_param_patterns(self):
+        layers_prefix = self._layers_prefix()
+        return [
+            layers_prefix + r"\d+.mlp.dense_4h_to_h.weight",
+            layers_prefix + r"\d+.self_attention.dense.weight",
+        ]
+
+    def _get_swiglu_col_parallel_param_patterns(self):
+        if not self.args.swiglu:
+            return []
+        layers_prefix = self._layers_prefix()
+        patterns = [
+            layers_prefix + r"\d+.mlp.dense_h_to_4h.weight",
+        ]
+        if self.args.add_bias_linear:
+            patterns.append(layers_prefix + r"\d+.mlp.dense_h_to_4h.bias")
+        return patterns
+
+
 class GPTModel(MegatronModule):
     """GPT-2 Language model."""
 
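Worth noting from the hunk above: _layers_prefix() captures the naming difference between the two model flavors. GPTModelPipe parameters are addressed by their layer index (hence the \d+. patterns), while GPTModel parameters live under language_model.encoder.layers. A small illustrative check of how such patterns are intended to match; the example parameter names are hypothetical but follow these conventions:

    # Illustrative sketch only; the parameter names below are made up for demonstration.
    import re

    pipe_param = "3.input_layernorm.weight"                                   # GPTModelPipe-style name
    nonpipe_param = "language_model.encoder.layers.3.input_layernorm.weight"  # GPTModel-style name

    assert re.match(r"\d+.input_layernorm.weight", pipe_param)
    assert re.match("language_model.encoder.layers." + r"\d+.input_layernorm.weight", nonpipe_param)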
@@ -177,36 +290,10 @@ def load_state_dict(self, state_dict, strict=True):
             state_dict["moe_state_dict"] = moe_state_dict
         self.language_model.load_state_dict(state_dict, strict=strict)
 
-    def _get_vocab_param_patterns(self):
-        args = get_args()
-        if args.untie_embeddings_and_output_weights:
-            patterns = [
-                r"\d+.word_embeddings.weight",
-                r"\d+.lm_head.weight"
-            ]
-        else:
-            patterns = [
-                r"tied_modules.embed.word_embeddings.weight"
-            ]
-        return patterns
-
     def universal_checkpoint_info(self):
-        info = dict()
-        args = get_args()
+        return UniversalCheckpointInfo(using_model_pipe=False).get()
 
-        if DS_UNIVERSAL_CHECKPOINT_INFO:
-            # Vocabulary parameters (embeddings) that require special handling due to padding.
-            info[VOCABULARY_PARAMETER_PATTERNS] = self._get_vocab_param_patterns()
-
-            if args.tensor_model_parallel_size > 1:
-                # Parameter slices that should be averaged not concatenated.
-                info[TP_REPLICATED_PARAMETER_PATTERNS] = self._get_tp_replicated_param_patterns()
-
-                # Parameter that are sliced on the row dimension
-                info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = self._get_row_parallel_param_patterns()
-
-        return info
 
 def CrossEntropy(output, labels):
     labels, loss_mask = labels[0], labels[1]
 
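The vocabulary patterns flagged above exist because the word-embedding rows are padded so the vocabulary divides evenly across tensor-parallel ranks. A hedged sketch of the kind of fix-up this enables when re-partitioning; the shapes and numbers are hypothetical, and this is not DeepSpeed's actual implementation:

    # Illustrative sketch only; not DeepSpeed code.
    import torch

    true_vocab = 50257                              # hypothetical tokenizer vocabulary size
    hidden = 1024
    padded = torch.randn(50304, hidden)             # embedding padded for the old TP degree

    stripped = padded[:true_vocab]                  # drop the padding rows before re-partitioning
    new_tp = 4
    new_rows = -(-true_vocab // new_tp) * new_tp    # re-pad to a multiple of the new TP degree
    repadded = torch.zeros(new_rows, hidden)
    repadded[:true_vocab] = stripped                # padding rows are re-created (zeros here, for illustration)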
@@ -336,88 +423,5 @@ def _logits_helper(embedding, lm_output):
                          activation_checkpoint_interval=interval,
                          partition_method='type:transformer')
 
-    @staticmethod
-    def _get_vocab_param_patterns():
-        args = get_args()
-        if args.untie_embeddings_and_output_weights:
-            patterns = [
-                r"\d+.word_embeddings.weight",
-                r"\d+.lm_head.weight"
-            ]
-        else:
-            patterns = [
-                r"tied_modules.embed.word_embeddings.weight"
-            ]
-        return patterns
-
-    def _get_pp_replicated_param_patterns(self):
-        args = get_args()
-        if args.untie_embeddings_and_output_weights:
-            return []
-        patterns = self._get_vocab_param_patterns()
-        if args.add_position_embedding:
-            patterns.append(r"tied_modules.embed.position_embeddings.weight")
-        return patterns
-
-    @staticmethod
-    def _get_tp_replicated_param_patterns():
-        args = get_args()
-        patterns = [
-            r"\d+.input_layernorm.weight",
-            r"\d+.post_attention_layernorm.weight",
-            r"\d+.weight",
-        ]
-        if args.add_position_embedding:
-            patterns.append(r"tied_modules.embed.position_embeddings.weight")
-        if args.add_bias_linear:
-            patterns.extend([
-                r"\d+.self_attention.dense.bias",
-                r"\d+.mlp.dense_4h_to_h.bias",
-            ])
-        if args.normalization == 'layernorm':
-            patterns.extend([
-                r"\d+.input_layernorm.bias",
-                r"\d+.post_attention_layernorm.bias",
-                r"\d+.bias",
-            ])
-        return patterns
-
-    @staticmethod
-    def _get_row_parallel_param_patterns():
-        return [
-            r"\d+.mlp.dense_4h_to_h.weight",
-            r"\d+.self_attention.dense.weight",
-        ]
-
-    @staticmethod
-    def _get_swiglu_col_parallel_param_patterns():
-        args = get_args()
-        if not args.swiglu:
-            return []
-        patterns = [
-            r"\d+.mlp.dense_h_to_4h.weight",
-        ]
-        if args.add_bias_linear:
-            patterns.append(r"\d+.mlp.dense_h_to_4h.bias")
-        return patterns
-
-
     def universal_checkpoint_info(self):
-        info = dict()
-        if DS_UNIVERSAL_CHECKPOINT_INFO:
-            # Vocabulary parameters (embeddings) that require special handling due to padding.
-            info[VOCABULARY_PARAMETER_PATTERNS] = self._get_vocab_param_patterns()
-
-            # Replicated (shared) parameters on the pipeline dimension
-            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = self._get_pp_replicated_param_patterns()
-
-            # Parameter slices that should be averaged not concatenated.
-            info[TP_REPLICATED_PARAMETER_PATTERNS] = self._get_tp_replicated_param_patterns()
-
-            # Parameter that are sliced on the row dimension
-            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = self._get_row_parallel_param_patterns()
-
-            # SWIGLU parameters are first sliced on dim=0 to tp slices
-            # Then, each tp slice is chunked into 2 to create the linear layers L1, L2 used for silu(L1(x)) * L2(x))
-            info[PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0] = self._get_swiglu_col_parallel_param_patterns()
-        return info
+        return UniversalCheckpointInfo(using_model_pipe=True).get()

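The SWIGLU comment above is compact; here is a minimal sketch of the split it describes, with illustrative shapes. This mirrors the comment, not the actual Megatron MLP implementation:

    # Illustrative sketch only: splitting one tp slice of dense_h_to_4h.weight into the
    # two SwiGLU linears L1 and L2 described in the comment above.
    import torch
    import torch.nn.functional as F

    hidden, ffn_hidden = 1024, 2730                  # hypothetical sizes for one tp slice
    w_h_to_4h = torch.randn(2 * ffn_hidden, hidden)  # stored as a single column-parallel weight
    x = torch.randn(4, hidden)                       # a toy batch of hidden states

    w1, w2 = torch.chunk(w_h_to_4h, 2, dim=0)        # chunk into 2 along dim=0 -> L1, L2
    out = F.silu(x @ w1.t()) * (x @ w2.t())          # silu(L1(x)) * L2(x), shape (4, ffn_hidden)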