Merged
27 commits
0b30bcb
FlexGen reference (#730)
tjruwase Sep 13, 2023
bae2afb
Update Llama check to use module instead of model (#734)
lekurile Sep 14, 2023
27b60d2
Filter transformers version 4.33.2 due to bug (#735)
lekurile Sep 19, 2023
9c94044
Update README.md (#739)
NinoRisteski Sep 22, 2023
db56381
support DeepSpeedChat to run on different device besides cuda (#736)
ys950902 Sep 22, 2023
9b3d898
support bf16 for RLHF training (#733)
ys950902 Sep 22, 2023
d8f3f73
deepspeed-chat: support any model in chatbot (#744)
mosheisland Oct 2, 2023
58e4e9c
Fix padding and dtype issues (#738)
tjruwase Oct 2, 2023
2f99dcd
deepspeed-chat: handle overflow for bf16_optimizer (#745)
mosheisland Oct 3, 2023
4bf1924
deepspeed-chat: support explicit configuration of dropout (#746)
mosheisland Oct 3, 2023
ca03bd7
deepspeed-chat: fix incorrect lr when using lora only (#756)
mosheisland Oct 3, 2023
0d11c63
Add default value for tokenizer path (#699)
xu-song Oct 3, 2023
ca41e8b
support `trust_remote_code` in inference test (#709)
wangruohui Oct 3, 2023
6c05e03
Deepspeed-VisualChat (#753)
yaozhewei Oct 3, 2023
4364031
Update README.md (#757)
xiaoxiawu-microsoft Oct 3, 2023
e6f400a
deepspeed-chat: calculate loss in fp32 (#754)
mosheisland Oct 4, 2023
bfad08f
deepspeed-chat: support periodic eval in stage2 (#747)
mosheisland Oct 4, 2023
10aef97
add the path to load the local dataset (#761)
ys950902 Oct 8, 2023
0855679
Fix typo (#749)
xu-song Oct 9, 2023
1ba50ed
Resolving epochs being hard-coded (#759)
PareesaMS Oct 11, 2023
3517c6d
Resolves the issue with evaluation on step2 for single GPU (#766)
PareesaMS Oct 12, 2023
5161c0f
deepspeed-chat: train v_head when only optimizing lora (#758)
mosheisland Oct 16, 2023
8d850ba
deepspeed-chat: fix weight decay configuration (#755)
mosheisland Oct 16, 2023
185e25c
deepspeed-chat: fix bf16 stage2 accuracy for bloom-560m (#772)
mosheisland Oct 17, 2023
f7ff9dd
deepspeed-chat: fix training stage1 ppl calculation (#773)
mosheisland Oct 17, 2023
e8d879e
deepspeed-chat: add end-of-text special token (#775)
mosheisland Oct 17, 2023
262ec5c
merge main into python_package
santacml Oct 18, 2023
18 changes: 16 additions & 2 deletions applications/DeepSpeed-Chat/deepspeed_chat/rlhf/ppo_trainer.py
@@ -7,13 +7,15 @@
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.accelerator import get_accelerator

from deepspeed_chat.utils.utils import print_rank_0


def print_all_ranks(tag, value, rank):
world_size = torch.distributed.get_world_size()
-all_tensor = torch.zeros(world_size, dtype=torch.float32).cuda()
+all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
+    get_accelerator().current_device_name())
all_tensor[rank] = value
torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
print_rank_0(f'{tag} {all_tensor}', rank)
@@ -53,6 +55,7 @@ def __init__(self, rlhf_engine, args):
self.end_of_conversation_token_id = self.tokenizer(
args.end_of_conversation_token)['input_ids'][-1]
self.z3_enabled = args.actor_zero_stage == 3
+self.compute_fp32_loss = self.args.compute_fp32_loss

# These values can be changed
self.kl_ctl = 0.1
@@ -70,7 +73,7 @@ def _generate_sequence(self, prompts, mask, step):
# This has been added due to a probability/nan error that happens after
# meta-llama/Llama-2-7b-hf enabled do_sample:
# https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
-if self.actor_model.model.config.model_type == "llama":
+if self.actor_model.module.config.model_type == "llama":
kwargs = dict(do_sample=False)
else:
kwargs = dict()
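
A note on the model-to-module fix above: deepspeed.initialize returns an engine object that wraps the original Hugging Face model, and the wrapped model is exposed as engine.module, so the config was never reachable via .model. A minimal sketch of the relationship (hf_model and ds_config are placeholders for whatever the caller constructs):

import deepspeed

def init_engine(hf_model, ds_config):
    # The engine wraps the original model; its Hugging Face config is
    # reachable at engine.module.config, not engine.model.config.
    engine, *_ = deepspeed.initialize(model=hf_model, config=ds_config)
    assert engine.module is hf_model
    return engine

# e.g. init_engine(model, cfg).module.config.model_type == "llama"
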
@@ -132,6 +135,9 @@ def generate_experience(self, prompts, mask, step):

logits = output.logits
logits_ref = output_ref.logits
+if self.compute_fp32_loss:
+    logits = logits.to(torch.float)
+    logits_ref = logits_ref.to(torch.float)

self.generate_time = generate_end - generate_start

@@ -237,6 +243,11 @@ def train_rlhf(self, inputs):
return actor_loss, critic_loss

def get_overflow(self):
+# Overflow is not expected when using bf16
+# Therefore, DeepSpeed's BF16_Optimizer does not maintain an overflow indication
+if self.args.dtype == "bf16":
+    return False, False
+
actor_overflow = self.actor_model.optimizer.overflow
critic_overflow = self.critic_model.optimizer.overflow

@@ -259,6 +270,9 @@ def critic_loss_fn(self, values, old_values, returns, mask):
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
+if self.compute_fp32_loss:
+    values = values.float()
+    values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
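
The ppo_trainer.py changes above share two themes: tensors are placed through DeepSpeed's accelerator abstraction instead of a hard-coded .cuda(), and losses can optionally be computed in fp32, since squared-error terms lose precision in bf16/fp16. A self-contained sketch of the fp32 loss pattern, mirroring critic_loss_fn above (the 0.2 clip range is an illustrative default, not the trainer's configured value):

import torch

def clipped_value_loss(values, old_values, returns, mask,
                       cliprange_value=0.2, compute_fp32_loss=True):
    # Clip the new value estimates to a trust region around the old ones.
    values_clipped = torch.clamp(values,
                                 old_values - cliprange_value,
                                 old_values + cliprange_value)
    if compute_fp32_loss:
        # bf16 keeps only ~8 mantissa bits; square and sum in fp32 so small
        # differences between values and returns are not rounded away.
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns) ** 2
    vf_loss2 = (values_clipped - returns) ** 2
    # Pessimistic (elementwise max) clipped value loss, masked and averaged.
    return 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
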
18 changes: 12 additions & 6 deletions applications/DeepSpeed-Chat/deepspeed_chat/rlhf/rlhf_engine.py
@@ -66,6 +66,7 @@ def _init_actor(self, actor_model_name_or_path):
# DS Config
ds_config = get_train_ds_config(
offload=self.args.offload,
+dtype=self.args.dtype,
stage=self.args.actor_zero_stage,
enable_hybrid_engine=self.args.enable_hybrid_engine,
inference_tp_size=self.args.inference_tp_size,
@@ -91,7 +92,7 @@
model_name_or_path=actor_model_name_or_path,
tokenizer=self.tokenizer,
ds_config=ds_config,
-disable_dropout=self.args.disable_actor_dropout)
+dropout=self.args.actor_dropout)

# LoRA
if self.args.actor_lora_dim > 0:
@@ -139,7 +140,7 @@ def _init_ref(self, actor_model_name_or_path):
# If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model
zero_stage = 0
ds_config = get_eval_ds_config(self.args.offload_reference_model,
-zero_stage)
+self.args.dtype, zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
#TODO(jeff): we should probably set grad accumulation steps here as well for clarity
@@ -165,7 +166,7 @@ def _init_ema(self, actor_model_name_or_path):
# If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory
zero_stage = 0
ds_config = get_eval_ds_config(self.args.offload_reference_model,
-zero_stage)
+self.args.dtype, zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
#TODO(jeff): we should probably set grad accumulation steps here as well for clarity
@@ -191,6 +192,7 @@ def _init_critic(self, critic_model_name_or_path):
stime = log_init("Critic")
ds_config = get_train_ds_config(
offload=self.args.offload,
+dtype=self.args.dtype,
stage=self.args.critic_zero_stage,
enable_tensorboard=self.args.enable_tensorboard,
tb_path=self.args.tensorboard_path,
@@ -203,6 +205,7 @@
) * self.args.gradient_accumulation_steps

ds_eval_config = get_eval_ds_config(offload=False,
+dtype=self.args.dtype,
stage=self.args.critic_zero_stage)
# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
@@ -218,7 +221,7 @@
ds_config=ds_eval_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
-disable_dropout=self.args.disable_critic_dropout,
+dropout=self.args.critic_dropout,
zero_stage=self.args.critic_zero_stage)

# LoRA
@@ -266,14 +269,17 @@ def _init_reward(self, critic_model_name_or_path):
zero_stage = 0

ds_config = get_eval_ds_config(offload=self.args.offload,
+dtype=self.args.dtype,
stage=zero_stage)
ds_config[
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size
ds_config[
'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size(
) * self.args.gradient_accumulation_steps

-ds_eval_config = get_eval_ds_config(offload=False, stage=zero_stage)
+ds_eval_config = get_eval_ds_config(offload=False,
+    dtype=self.args.dtype,
+    stage=zero_stage)

# We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine.
ds_eval_config[
@@ -289,7 +295,7 @@
ds_config=ds_eval_config,
num_padding_at_beginning=self.args.num_padding_at_beginning,
rlhf_training=True,
-disable_dropout=self.args.disable_critic_dropout,
+dropout=self.args.critic_dropout,
zero_stage=zero_stage)

reward_engine, *_ = deepspeed.initialize(model=reward_model,
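
One interface change threads through all four initializers above: the boolean disable_dropout flag becomes an explicit dropout value, with None meaning "keep the model's configured default". A hedged sketch of how such an override can be applied to a Hugging Face config (the helper name and the attribute list are illustrative assumptions, not the repo's exact code):

from transformers import AutoConfig

def configure_dropout(model_name_or_path, dropout=None):
    # Load the model config and, if requested, override its dropout rates.
    config = AutoConfig.from_pretrained(model_name_or_path)
    if dropout is not None:
        # Architectures name their dropout fields differently; set whichever
        # of the common ones exist on this particular config.
        for field in ("dropout", "attention_dropout",
                      "hidden_dropout", "activation_dropout"):
            if hasattr(config, field):
                setattr(config, field, dropout)
    return config

# dropout=0.0 reproduces the old disable_dropout=True behavior;
# dropout=None leaves the pretrained defaults untouched.
cfg = configure_dropout("facebook/opt-1.3b", dropout=0.0)
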
4 changes: 3 additions & 1 deletion applications/DeepSpeed-Chat/deepspeed_chat/utils/data/data_utils.py
@@ -15,6 +15,7 @@
import hashlib
from itertools import chain
from deepspeed_chat.utils.data import raw_datasets
+from deepspeed.accelerator import get_accelerator


def get_raw_dataset(dataset_name, output_path, seed, local_rank):
@@ -281,7 +282,8 @@ def create_prompt_dataset(local_rank,
eval_fname = f"{output_path}/evaldata_{fname}.pt"

cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
-buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
+buf_create_cache = torch.ByteTensor([not cache_found]).to(
+    get_accelerator().current_device_name())
torch.distributed.all_reduce(buf_create_cache)

if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
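
The cache check in create_prompt_dataset is a small distributed idiom: each rank raises a byte flag if its cached files are missing, an all-reduce sums the flags so the decision is unanimous across ranks, and only rank 0 rebuilds. A simplified sketch, assuming an initialized torch.distributed process group (rebuild_fn is a placeholder for the tokenize-and-save step):

import os
import torch
from deepspeed.accelerator import get_accelerator

def ensure_cached(train_fname, eval_fname, local_rank, rebuild_fn):
    cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
    # 1 means "my cache is missing"; summing across ranks forces a rebuild
    # if any rank (e.g. on a freshly provisioned node) lacks the files.
    flag = torch.ByteTensor([not cache_found]).to(
        get_accelerator().current_device_name())
    torch.distributed.all_reduce(flag)
    if local_rank <= 0 and flag.item() != 0:
        rebuild_fn(train_fname, eval_fname)
    # Everyone waits until rank 0 has finished writing before loading.
    torch.distributed.barrier()
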
7 changes: 5 additions & 2 deletions applications/DeepSpeed-Chat/deepspeed_chat/utils/data/raw_datasets.py
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

+import os
# DeepSpeed Team
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
from torch.utils.data import Subset
import re

@@ -15,7 +16,9 @@ def __init__(self, output_path, seed, local_rank, dataset_name):
self.output_path = output_path
self.seed = seed
self.local_rank = local_rank
-if not dataset_name == 'local/jsonfile':
+if os.path.exists(dataset_name):
+    self.raw_datasets = load_from_disk(dataset_name)
+elif not dataset_name == 'local/jsonfile':
self.raw_datasets = load_dataset(dataset_name)

def get_train_data(self):
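
With this change a dataset name is resolved three ways: an existing local path saved with the datasets library's save_to_disk, the special local/jsonfile name, or a Hugging Face Hub id. A standalone sketch of the resolution order (the final fallthrough reflects that local/jsonfile is handled by a dedicated class elsewhere in this file):

import os
from datasets import load_dataset, load_from_disk

def resolve_dataset(dataset_name):
    if os.path.exists(dataset_name):
        # A directory produced by Dataset.save_to_disk / DatasetDict.save_to_disk.
        return load_from_disk(dataset_name)
    if dataset_name != 'local/jsonfile':
        return load_dataset(dataset_name)  # treat as a Hub dataset id
    return None  # 'local/jsonfile' is loaded from user-supplied JSON elsewhere
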
34 changes: 22 additions & 12 deletions applications/DeepSpeed-Chat/deepspeed_chat/utils/ds_utils.py
@@ -3,14 +3,15 @@

# DeepSpeed Team

-import torch
import deepspeed.comm as dist
+from deepspeed.accelerator import get_accelerator

GLOBAL_BATCH_SIZE = 32
MICRO_BATCH_SIZE = 4


def get_train_ds_config(offload,
+dtype,
stage=2,
enable_hybrid_engine=False,
inference_tp_size=1,
@@ -24,6 +25,12 @@ def get_train_ds_config(offload,
tb_name=""):

device = "cpu" if offload else "none"
if dtype == "fp16":
data_type = "fp16"
dtype_config = {"enabled": True, "loss_scale_window": 100}
elif dtype == "bf16":
data_type = "bfloat16"
dtype_config = {"enabled": True}
zero_opt_dict = {
"stage": stage,
"offload_param": {
@@ -39,18 +46,15 @@
}
if enable_mixed_precision_lora:
zero_opt_dict["zero_quantized_nontrainable_weights"] = True
-if dist.get_world_size() != torch.cuda.device_count():
-    zero_opt_dict["zero_hpz_partition_size"] = torch.cuda.device_count(
-    )
+if dist.get_world_size() != get_accelerator().device_count():
+    zero_opt_dict["zero_hpz_partition_size"] = get_accelerator(
+    ).device_count()
return {
"train_batch_size": GLOBAL_BATCH_SIZE,
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {
"enabled": True,
"loss_scale_window": 100
},
data_type: dtype_config,
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False,
@@ -70,8 +74,16 @@
}


-def get_eval_ds_config(offload, stage=0):
+def get_eval_ds_config(offload, dtype, stage=0):
device = "cpu" if offload else "none"
if dtype == "fp16":
data_type = "fp16"
dtype_config = {
"enabled": True,
}
elif dtype == "bf16":
data_type = "bfloat16"
dtype_config = {"enabled": True}
zero_opt_dict = {
"stage": stage,
"stage3_param_persistence_threshold": 1e4,
@@ -85,9 +97,7 @@ def get_eval_ds_config(offload, stage=0):
"train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {
"enabled": True
},
data_type: dtype_config,
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False
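
The net effect of the dtype plumbing: the precision key of the emitted config is chosen at call time instead of being hard-wired to fp16. An illustrative check of the two supported values (the import path is inferred from the package layout above; with default arguments the helper needs no initialized distributed backend, since the world-size check only runs for mixed-precision LoRA):

from deepspeed_chat.utils.ds_utils import get_train_ds_config

cfg = get_train_ds_config(offload=False, dtype="fp16", stage=2)
assert cfg["fp16"] == {"enabled": True, "loss_scale_window": 100}

cfg = get_train_ds_config(offload=False, dtype="bf16", stage=2)
assert cfg["bfloat16"] == {"enabled": True}

# Note: any other dtype string leaves data_type/dtype_config unbound and
# raises UnboundLocalError; the helpers assume exactly "fp16" or "bf16".
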