applications/DeepSpeed-Chat/training/utils/model/model_utils.py
@@ -14,6 +14,7 @@

 from .reward_model import RewardModel
 from ..utils import load_state_dict_into_model
+from ..utils import print_rank_0


 def configure_dropout(model_config, dropout):
@@ -130,8 +131,8 @@ def create_critic_model(model_name_or_path,
     critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                    ds_config, rlhf_training, dropout)
     end = time.time()
-    if torch.distributed.get_rank() == 0:
-        print(f"> Creating model from_config took {end - start} seconds")
+    print_rank_0(f"> Creating model from_config took {end - start} seconds",
+                 None)

     critic_model = RewardModel(
         critic_model,
@@ -152,8 +153,8 @@ def create_critic_model(model_name_or_path,
         start = time.time()
         model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
         end = time.time()
-        if torch.distributed.get_rank() == 0:
-            print(f"> torch.load took {end - start} seconds")
+        print_rank_0(f"> torch.load took {end - start} seconds",
+                     None)

         # load critic model from checkpoint with zero-stage 3 compatibility
         # this functionality may be moved to DS checkpoint load API in future
@@ -163,7 +164,8 @@ def create_critic_model(model_name_or_path,
"",
zero_stage=zero_stage)
end = time.time()
if torch.distributed.get_rank() == 0:
print(f"> Loading model state dict took {end - start} seconds")

print_rank_0(f">Creating model from_config took {end - start} seconds",
None)

return critic_model
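
The three hunks above apply one pattern: each hand-rolled rank-0 gate around print becomes a print_rank_0 call. A minimal before/after sketch of that pattern, with an illustrative timed step and message string, assuming the print_rank_0 import shown in the first hunk:

import time
import torch
from ..utils import print_rank_0

start = time.time()
time.sleep(0.1)  # stand-in for the timed work, e.g. loading a checkpoint
end = time.time()

# Before: raises if torch.distributed was never initialized, because
# get_rank() requires the default process group (e.g. plain single-GPU runs).
if torch.distributed.get_rank() == 0:
    print(f"> step took {end - start} seconds")

# After: rank=None defers to is_rank_0(), which treats an uninitialized
# process group as rank 0, so non-distributed runs still log.
print_rank_0(f"> step took {end - start} seconds", None)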
applications/DeepSpeed-Chat/training/utils/utils.py (17 changes: 15 additions & 2 deletions)
@@ -14,9 +14,22 @@
 import torch.nn as nn


-def print_rank_0(msg, rank=0):
-    if rank <= 0:
+def print_rank_0(msg, rank=None):
+    if rank is not None and rank <= 0:
         print(msg)
+    elif is_rank_0():
+        print(msg)


+def is_rank_0():
+    """Check whether it is rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            return True
+        else:
+            return False
+    else:
+        return True
+
+
 def to_device(batch, device):
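
The rewritten print_rank_0 now has two gating paths, which the sketch below walks through; the messages are hypothetical, and the import path assumes a caller inside the training tree of this repo:

from utils.utils import print_rank_0

# Path 1: an explicit rank is supplied. rank <= 0 prints unconditionally;
# note a positive rank is not an immediate "no" but falls through to path 2.
print_rank_0("explicit gate", rank=0)    # always prints

# Path 2: rank is None (the new default), so is_rank_0() decides: it
# returns True when torch.distributed is uninitialized (single-process
# scripts) or when this process is global rank 0.
print_rank_0("default gate")             # prints on rank 0 and in
                                         # non-distributed runs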