Merge branch 'main' into shashankv-readme-update
shashank3959 authored Apr 5, 2024
2 parents d7d2514 + cf3b3a5 commit a8be773
Showing 1 changed file with 22 additions and 2 deletions.
@@ -770,8 +770,28 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:

         # Make sure embedding grad reductions are in FP32
         if optim_dtype == torch.float32:
-            for name, param in self.named_parameters():
-                if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name:
+            fp32_params = []
+            modules = self.get_model_module_list()
+            if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+                if self.mcore_gpt:
+                    fp32_params.append(modules[0].shared_embedding_or_output_weight())
+                    fp32_params.append(modules[0].embedding.position_embeddings.weight)
+                else:
+                    fp32_params.append(modules[0].word_embeddings_weight())
+                    fp32_params.append(modules[0].position_embeddings_weight())
+            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+                share_embeddings_and_output_weights = (
+                    modules[-1].share_embeddings_and_output_weights
+                    if self.mcore_gpt
+                    else modules[-1].share_token_embeddings
+                )
+                if share_embeddings_and_output_weights:
+                    if self.mcore_gpt:
+                        fp32_params.append(modules[-1].shared_embedding_or_output_weight())
+                    else:
+                        fp32_params.append(modules[-1].word_embeddings_weight())
+            for param in fp32_params:
+                if param is not None:
                     param._with_fp32_optimizer = True
 
         # Match param allgather with model dtype
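
The `_with_fp32_optimizer` attribute set above is only a marker on the parameter object; something downstream has to read it when optimizer state is built. Below is a minimal, hypothetical sketch of how such a flag could be consumed, assuming names like `collect_fp32_params` and `make_master_copies` that are illustrative only and not the actual NeMo/Apex API.

import torch

def collect_fp32_params(model: torch.nn.Module):
    # Gather every parameter tagged for FP32 optimizer/grad handling,
    # e.g. the embedding and (possibly shared) output-layer weights above.
    return [p for p in model.parameters() if getattr(p, '_with_fp32_optimizer', False)]

def make_master_copies(params):
    # Keep an FP32 master copy per tagged parameter so grad reductions and
    # optimizer updates stay in full precision even when the model runs in BF16/FP16.
    return {p: p.detach().clone().float() for p in params}

Note also that the diff replaces a name-based match ('word_embedding', 'position_embedding', 'output_layer' substrings) with explicit lookups on the first and last pipeline stages, so the flag lands on the exact shared embedding/output weights regardless of how parameters happen to be named.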
