Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--dtype', type=str, default='fp16',
choices=['fp16', 'bf16'],
help = 'Training data type')
parser.add_argument(
'--zero_stage',
type=int,
Expand Down Expand Up @@ -202,6 +205,7 @@ def main():
args.global_rank = torch.distributed.get_rank()

ds_config = get_train_ds_config(offload=args.offload,
dtype=args.dtype,
stage=args.zero_stage,
enable_tensorboard=args.enable_tensorboard,
tb_path=args.tensorboard_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--dtype', type=str, default='fp16',
choices=['fp16', 'bf16'],
help = 'Training data type')
parser.add_argument(
'--zero_stage',
type=int,
Expand Down Expand Up @@ -197,6 +200,7 @@ def main():
args.global_rank = torch.distributed.get_rank()

ds_config = get_train_ds_config(offload=args.offload,
dtype=args.dtype,
stage=args.zero_stage,
enable_tensorboard=args.enable_tensorboard,
tb_path=args.tensorboard_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def parse_args():
parser.add_argument('--offload',
action='store_true',
help='Enable ZeRO Offload techniques.')
parser.add_argument('--dtype', type=str, default='fp16',
choices=['fp16', 'bf16'],
help = 'Training data type')
parser.add_argument(
'--offload_reference_model',
action='store_true',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _init_actor(self, actor_model_name_or_path):
# DS Config
ds_config = get_train_ds_config(
offload=self.args.offload,
dtype=self.args.dtype,
stage=self.args.actor_zero_stage,
enable_hybrid_engine=self.args.enable_hybrid_engine,
inference_tp_size=self.args.inference_tp_size,
Expand Down Expand Up @@ -139,6 +140,7 @@ def _init_ref(self, actor_model_name_or_path):
# If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model
zero_stage = 0
ds_config = get_eval_ds_config(self.args.offload_reference_model,
self.args.dtype,
zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
Expand All @@ -165,6 +167,7 @@ def _init_ema(self, actor_model_name_or_path):
# If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0
ds_config = get_eval_ds_config(self.args.offload_reference_model,
self.args.dtype,
zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
Expand All @@ -191,6 +194,7 @@ def _init_critic(self, critic_model_name_or_path):
stime = log_init("Critic")
ds_config = get_train_ds_config(
offload=self.args.offload,
dtype=self.args.dtype,
stage=self.args.critic_zero_stage,
enable_tensorboard=self.args.enable_tensorboard,
tb_path=self.args.tensorboard_path,
Expand All @@ -203,6 +207,7 @@ def _init_critic(self, critic_model_name_or_path):
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
dtype=self.args.dtype,
stage=self.args.critic_zero_stage)
# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
Expand Down Expand Up @@ -266,14 +271,15 @@ def _init_reward(self, critic_model_name_or_path):
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False, stage=zero_stage)
ds_eval_config = get_eval_ds_config(offload=False, dtype=self.args.dtype, stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
Expand Down
33 changes: 25 additions & 8 deletions applications/DeepSpeed-Chat/training/utils/ds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


def get_train_ds_config(offload,
dtype,
stage=2,
enable_hybrid_engine=False,
inference_tp_size=1,
Expand All @@ -25,6 +26,17 @@ def get_train_ds_config(offload,
tb_name=""):

device = "cpu" if offload else "none"
if dtype == "fp16":
data_type = "fp16"
dtype_config = {
"enabled": True,
"loss_scale_window": 100
}
elif dtype == "bf16":
data_type = "bfloat16"
dtype_config = {
"enabled": True
}
zero_opt_dict = {
"stage": stage,
"offload_param": {
Expand All @@ -48,10 +60,7 @@ def get_train_ds_config(offload,
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {
"enabled": True,
"loss_scale_window": 100
},
data_type: dtype_config,
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False,
Expand All @@ -71,8 +80,18 @@ def get_train_ds_config(offload,
}


def get_eval_ds_config(offload, stage=0):
def get_eval_ds_config(offload, dtype, stage=0):
device = "cpu" if offload else "none"
if dtype == "fp16":
data_type = "fp16"
dtype_config = {
"enabled": True,
}
elif dtype == "bf16":
data_type = "bfloat16"
dtype_config = {
"enabled": True
}
zero_opt_dict = {
"stage": stage,
"stage3_param_persistence_threshold": 1e4,
Expand All @@ -86,9 +105,7 @@ def get_eval_ds_config(offload, stage=0):
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {
"enabled": True
},
data_type: dtype_config,
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False
Expand Down