diff --git a/applications/DeepSpeed-Chat/training/utils/model/model_utils.py b/applications/DeepSpeed-Chat/training/utils/model/model_utils.py
index 508b58317..147ef9678 100644
--- a/applications/DeepSpeed-Chat/training/utils/model/model_utils.py
+++ b/applications/DeepSpeed-Chat/training/utils/model/model_utils.py
@@ -14,6 +14,7 @@
 
 from .reward_model import RewardModel
 from ..utils import load_state_dict_into_model
+from ..utils import print_rank_0
 
 
 def configure_dropout(model_config, dropout):
@@ -130,8 +131,8 @@ def create_critic_model(model_name_or_path,
     critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                    ds_config, rlhf_training, dropout)
     end = time.time()
-    if torch.distributed.get_rank() == 0:
-        print(f"> Creating model from_config took {end - start} seconds")
+    print_rank_0(f"> Creating model from_config took {end - start} seconds",
+                 None)
 
     critic_model = RewardModel(
         critic_model,
@@ -152,8 +153,8 @@ def create_critic_model(model_name_or_path,
         start = time.time()
         model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
         end = time.time()
-        if torch.distributed.get_rank() == 0:
-            print(f"> torch.load took {end - start} seconds")
+        print_rank_0(f"> torch.load took {end - start} seconds",
+                     None)
 
         # load critic model from checkpoint with zero-stage 3 compatibility
         # this functionality may be moved to DS checkpoint load API in future
@@ -163,7 +164,8 @@ def create_critic_model(model_name_or_path,
                                   "",
                                   zero_stage=zero_stage)
         end = time.time()
-        if torch.distributed.get_rank() == 0:
-            print(f"> Loading model state dict took {end - start} seconds")
+
+        print_rank_0(f"> Loading model state dict took {end - start} seconds",
+                     None)
 
     return critic_model
diff --git a/applications/DeepSpeed-Chat/training/utils/utils.py b/applications/DeepSpeed-Chat/training/utils/utils.py
index b5cfb8d6d..78180235e 100644
--- a/applications/DeepSpeed-Chat/training/utils/utils.py
+++ b/applications/DeepSpeed-Chat/training/utils/utils.py
@@ -14,9 +14,22 @@
 import torch.nn as nn
 
 
-def print_rank_0(msg, rank=0):
-    if rank <= 0:
+def print_rank_0(msg, rank=None):
+    if rank is not None and rank <= 0:
         print(msg)
+    elif is_rank_0():
+        print(msg)
+
+
+def is_rank_0():
+    """Check whether it is rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            return True
+        else:
+            return False
+    else:
+        return True
 
 
 def to_device(batch, device):
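
For context, the utils.py hunk changes print_rank_0 so callers may omit the rank argument: an explicit rank <= 0 still prints, and otherwise the helper queries torch.distributed itself via is_rank_0(), which also covers single-process runs where no process group was ever initialized (the old inline guard torch.distributed.get_rank() == 0 fails in that case). A minimal standalone sketch of that behavior, condensed from the diff; the demo messages and timing value are illustrative only:

import torch
import torch.distributed


def is_rank_0():
    """True on global rank 0, and on single-process runs with no process group."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    return True


def print_rank_0(msg, rank=None):
    if rank is not None and rank <= 0:
        # Caller supplied its rank explicitly (old call style): trust it.
        print(msg)
    elif is_rank_0():
        # No rank supplied: ask torch.distributed directly.
        print(msg)


# Without init_process_group(), is_rank_0() is True, so each call prints once:
print_rank_0("> Creating model from_config took 1.23 seconds")           # new style
print_rank_0("> Creating model from_config took 1.23 seconds", rank=0)   # old style

This is why the model_utils.py call sites can pass None instead of wrapping every print in a torch.distributed.get_rank() == 0 check.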