@@ -29,7 +29,7 @@
from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
from utils.ds_utils import get_train_ds_config
from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
-from utils.model.model_utils import create_hf_model
+from utils.model.model_utils import create_hf_model, causal_lm_model_to_fp32_loss
from utils.perf import print_throughput


@@ -178,6 +178,12 @@ def parse_args():
help=
"Initial LoRA learning rate (after the potential warmup period) to use."
)
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
@@ -234,6 +240,12 @@ def main():
ds_config,
dropout=args.dropout)

if args.compute_fp32_loss:
print_rank_0(
f"Using model {model.__class__.__name__} with loss in fp32",
args.global_rank)
causal_lm_model_to_fp32_loss(model)

if args.lora_dim > 0:
model = convert_linear_layer_to_lora(model, args.lora_module_name,
args.lora_dim)
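The hunks above wire the new option into a script that builds its model with create_hf_model: --compute_fp32_loss is a plain store_true flag, and when it is passed the freshly created model is patched with causal_lm_model_to_fp32_loss (the helper added to utils/model/model_utils.py further down, where a short usage sketch is also given) before the optional LoRA conversion. Only the loss computation changes; the model's parameters and their dtype are left untouched.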
@@ -177,6 +177,12 @@ def parse_args():
help=
"Initial LoRA learning rate (after the potential warmup period) to use."
)
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
@@ -226,7 +232,9 @@ def main():
tokenizer,
ds_config,
args.num_padding_at_beginning,
-                                   dropout=args.dropout)
+                                   dropout=args.dropout,
+                                   zero_stage=args.zero_stage,
+                                   compute_fp32_loss=args.compute_fp32_loss)

if args.lora_dim > 0:
rm_model = convert_linear_layer_to_lora(rm_model,
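The hunks above expose the same flag in a script that builds its model through create_critic_model instead. Nothing is patched in place here: the flag is forwarded as a keyword argument (together with the ZeRO stage), create_critic_model passes compute_fp32_loss on to RewardModel, and the cast itself happens inside RewardModel.forward, shown in the last file of this diff.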
@@ -325,6 +325,13 @@ def parse_args():
'--enable_mixed_precision_lora',
action='store_true',
help='Enable Mixed Precision ZeRO++ for training and generation.')
## low precision
parser.add_argument(
'--compute_fp32_loss',
action='store_true',
help='Relevant for low precision dtypes (fp16, bf16, etc.). '
'If specified, loss is calculated in fp32.'
'This applies for both actor and critic models.')
## Tensorboard logging
parser.add_argument('--enable_tensorboard',
action='store_true',
@@ -572,13 +579,13 @@ def main():
average_reward / inner_iter,
global_step=step)
writer.add_scalar('actor_loss',
-                                      actor_loss,
+                                      actor_loss.item(),
global_step=step)
writer.add_scalar('actor_loss_sum',
actor_loss_sum,
global_step=step)
writer.add_scalar('critic_loss',
-                                      critic_loss,
+                                      critic_loss.item(),
global_step=step)
writer.add_scalar('critic_loss_sum',
critic_loss_sum,
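Two independent changes above: the same --compute_fp32_loss flag is added (its help text notes that it applies to both the actor and the critic), and the TensorBoard calls now log actor_loss.item() and critic_loss.item() instead of the tensors themselves. A minimal illustration of what .item() changes (not from the PR):

import torch

actor_loss = (torch.randn(4, requires_grad=True) ** 2).mean()  # stand-in for the real actor loss
print(type(actor_loss))         # <class 'torch.Tensor'>, still attached to the autograd graph
print(type(actor_loss.item()))  # <class 'float'>, a plain number that is safe to hand to add_scalar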
@@ -60,6 +60,7 @@ def __init__(self, rlhf_engine, args):
self.end_of_conversation_token_id = self.tokenizer(
args.end_of_conversation_token)['input_ids'][-1]
self.z3_enabled = args.actor_zero_stage == 3
self.compute_fp32_loss = self.args.compute_fp32_loss

# Those value can be changed
self.kl_ctl = 0.1
@@ -139,6 +140,9 @@ def generate_experience(self, prompts, mask, step):

logits = output.logits
logits_ref = output_ref.logits
if self.compute_fp32_loss:
logits = logits.to(torch.float)
logits_ref = logits_ref.to(torch.float)

self.generate_time = generate_end - generate_start

@@ -271,6 +275,9 @@ def critic_loss_fn(self, values, old_values, returns, mask):
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
if self.compute_fp32_loss:
values = values.float()
values_clipped = values_clipped.float()
vf_loss1 = (values - returns)**2
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
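In the PPO trainer the flag is read once in __init__ and then used in two places: generate_experience up-casts the actor and reference logits before the log-prob computations that follow, and critic_loss_fn up-casts the values and their clipped counterpart before the squared errors. Below is a self-contained sketch of the value-loss part; the names mirror the trainer, but this is a simplified stand-in, not the PR's code:

import torch

def critic_loss(values, old_values, returns, mask, cliprange_value=0.2, compute_fp32_loss=True):
    # clip the new value predictions to a band around the old ones
    values_clipped = torch.clamp(values,
                                 old_values - cliprange_value,
                                 old_values + cliprange_value)
    if compute_fp32_loss:
        values = values.float()
        values_clipped = values_clipped.float()
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    # masked mean of the element-wise larger of the two squared errors
    return 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()

values = torch.randn(2, 8).to(torch.bfloat16)
returns = torch.randn(2, 8).to(torch.bfloat16)
print(critic_loss(values, values.detach(), returns, torch.ones(2, 8)).dtype)  # torch.float32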
applications/DeepSpeed-Chat/training/utils/model/model_utils.py (59 additions, 2 deletions)
@@ -25,6 +25,61 @@ def configure_dropout(model_config, dropout):
setattr(model_config, key, dropout)


def causal_lm_model_to_fp32_loss(model):
""" Convert CausalLM model to calculate loss in fp32 """

def causal_lm_forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**deprecated_arguments,
):
output = model.__original_forward__(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

return_dict = isinstance(output, dict)
lm_logits = output.logits if return_dict else output[0]
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length))

if not return_dict:
# re-pack output with fp32 loss
return ((loss, ) + output) if loss is not None else output

output.loss = loss
return output

model.__original_forward__ = model.forward
model.forward = causal_lm_forward


def create_hf_model(model_class,
model_name_or_path,
tokenizer,
@@ -64,7 +119,8 @@ def create_critic_model(model_name_or_path,
num_padding_at_beginning=0,
rlhf_training=False,
dropout=None,
-                        zero_stage=0):
+                        zero_stage=0,
+                        compute_fp32_loss=False):
# OPT model family always put a padding token at the beginning of the sequence,
# we did not see this in other models but not sure if it is a general rule

@@ -80,7 +136,8 @@ def create_critic_model(model_name_or_path,
critic_model = RewardModel(
critic_model,
tokenizer,
-        num_padding_at_beginning=num_padding_at_beginning)
+        num_padding_at_beginning=num_padding_at_beginning,
+        compute_fp32_loss=compute_fp32_loss)

if rlhf_training:
# load critic model from checkpoint
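The new helper above monkey-patches the model rather than changing its weights or dtype: the original forward is stashed as model.__original_forward__, the wrapper calls it with labels=None, and the shifted cross-entropy is then recomputed on the logits cast to fp32, so the returned loss is fp32 while the returned logits keep the model's low-precision dtype. (Keyword arguments not listed in the wrapper's signature, such as position_ids, are absorbed by **deprecated_arguments and not forwarded.) create_critic_model, in turn, simply threads the new compute_fp32_loss argument through to RewardModel. A quick sketch of the helper's behaviour follows; it is not part of the PR, the checkpoint name and bf16 dtype are arbitrary choices, and the import assumes the DeepSpeed-Chat training directory is on sys.path:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.model.model_utils import causal_lm_model_to_fp32_loss  # the helper added above

name = "facebook/opt-125m"  # small stand-in checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

causal_lm_model_to_fp32_loss(model)  # forward is now the fp32-loss wrapper

batch = tokenizer("loss is computed in fp32", return_tensors="pt")
out = model(**batch, labels=batch["input_ids"])
print(out.logits.dtype, out.loss.dtype)  # torch.bfloat16 torch.float32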
@@ -10,7 +10,11 @@
## https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
class RewardModel(nn.Module):

-    def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
+    def __init__(self,
+                 base_model,
+                 tokenizer,
+                 num_padding_at_beginning=0,
+                 compute_fp32_loss=False):
super().__init__()
self.config = base_model.config
self.num_padding_at_beginning = num_padding_at_beginning
@@ -27,6 +31,7 @@ def __init__(self, base_model, tokenizer, num_padding_at_beginning=0):
self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
self.rwtranrsformer = base_model
self.PAD_ID = tokenizer.pad_token_id
self.compute_fp32_loss = compute_fp32_loss

def gradient_checkpointing_enable(self):
self.rwtranrsformer.gradient_checkpointing_enable()
@@ -73,7 +78,7 @@ def forward(self,
rejected_rewards = rewards[bs:]

# Compute pairwise loss. Only backprop on the different tokens before padding
-        loss = 0
+        loss = 0.
for i in range(bs):
chosen_id = chosen_ids[i]
rejected_id = rejected_ids[i]
@@ -104,6 +109,9 @@
chosen_reward[c_ind - 1]) #use the end score for reference
rejected_mean_scores.append(rejected_reward[r_ind - 1])

if self.compute_fp32_loss:
c_truncated_reward = c_truncated_reward.float()
r_truncated_reward = r_truncated_reward.float()
loss += -torch.nn.functional.logsigmoid(c_truncated_reward -
r_truncated_reward).mean()
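In RewardModel the flag arrives through the constructor and only touches the pairwise ranking loss: the per-token rewards over the chosen/rejected span are cast to fp32 right before the logsigmoid term, and the loss accumulator is now initialised with the float literal 0. instead of the integer 0. A self-contained sketch of that core (simplified; the real forward also locates the divergent span and the end scores):

import torch
import torch.nn.functional as F

def pairwise_loss(c_truncated_reward, r_truncated_reward, compute_fp32_loss=False):
    # per-token rewards over the span where the chosen and rejected answers differ
    if compute_fp32_loss:
        c_truncated_reward = c_truncated_reward.float()
        r_truncated_reward = r_truncated_reward.float()
    return -F.logsigmoid(c_truncated_reward - r_truncated_reward).mean()

chosen = torch.randn(16).to(torch.float16)
rejected = torch.randn(16).to(torch.float16)
print(pairwise_loss(chosen, rejected, compute_fp32_loss=True).dtype)  # torch.float32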
