Skip to content

Commit

Permalink
Update megatron/training.py
Browse files Browse the repository at this point in the history
  • Loading branch information
saforem2 committed Apr 4, 2024
1 parent 8ac8bdc commit 8c6c91f
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,17 @@ def _create_ds_config_dict():
return ds_config_dict


def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={},
data_post_process=None,
external_args={}):
def pretrain(
train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={},
data_post_process=None,
external_args={},
) -> torch.nn.Module:
"""Main training program.
This function will run the followings in the order provided:
Expand Down Expand Up @@ -149,6 +151,9 @@ def pretrain(train_valid_test_dataset_provider,
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
to set already parse arguments.
Returns:
model (torch.nn.Module)
"""

# Initalize and get arguments, timers, and Tensorboard writer.
Expand Down

0 comments on commit 8c6c91f

Please sign in to comment.