
Commit 5e92f77

add hugchat reward training
1 parent 1b9f416 commit 5e92f77

10 files changed: +632 -48 lines changed
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
path=/wjn/pre-trained-lm/gpt2

model_name=gpt2

data_path=/wjn/nlp_task_datasets/rlhf_preference # consists of preference_train.json, preference_dev.json, preference_test.json


export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python3 -m torch.distributed.launch --nproc_per_node=8 --master_port 6013 hugnlp_runner.py \
--model_name_or_path=$path \
--data_dir=$data_path \
--max_seq_length=512 \
--output_dir=./outputs/rlhf/$model_name/ \
--do_train \
--do_eval \
--do_predict \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=1 \
--evaluation_strategy=steps \
--save_strategy=steps \
--gradient_accumulation_steps=1 \
--learning_rate=1e-05 \
--logging_steps=10000000 \
--eval_steps=3000 \
--save_steps=3000 \
--save_total_limit=10 \
--num_train_epochs=3 \
--report_to=none \
--task_name=pairwise_reward \
--task_type=rl_reward \
--model_type=gpt2 \
--exp_name=preference_reward \
--warmup_steps=6000 \
--load_best_model_at_end \
--metric_for_best_model=acc \
--ignore_data_skip \
--remove_unused_columns=False \
--cache_dir=/wjn/.cache \
--overwrite_output_dir \
# --deepspeed=./deepspeed/ds_config_fp16_z1.json \
# --fp16
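
The data_path comment above says the preference corpus is split into preference_train.json, preference_dev.json and preference_test.json, but the schema is not shown in this commit. A hypothetical record, with field names such as "prompt", "chosen" and "rejected" assumed purely for illustration, might look like this:

# Hypothetical pairwise-preference record; the real field names expected by
# the pairwise_reward processor are not part of this diff.
example = {
    "idx": "train-0",
    "prompt": "Explain what a reward model is.",
    "chosen": "A reward model scores how well a response matches human preferences ...",
    "rejected": "I don't know.",
}

Each pair supplies a preferred (chosen) and a dispreferred (rejected) response for the same prompt, which is what the pairwise reward losses added below consume.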

evaluators/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -8,6 +8,7 @@
 from evaluators.token_classification_evaluator import TokenClassificationEvaluator
 from evaluators.span_extraction_evaluator import SpanExtractionEvaluator
 from evaluators.multi_choice_evaluator import MultiChoiceEvaluator
+from evaluators.reinforcement_learning_evaluator import PairwiseRewardEvaluator

 # Models for pre-training
 PRETRAIN_EVALUATOR_CLASSES = {
@@ -57,6 +58,11 @@
     "code_generation": None,
 }

+REINFORCEMENT_MODEL_CLASSES = {
+    "causal_actor": None,
+    "auto_critic": None,
+    "rl_reward": PairwiseRewardEvaluator,
+}

 # task_type maps to the corresponding model type
 OTHER_EVALUATOR_CLASSES = {
@@ -101,6 +107,7 @@
     SPAN_EXTRACTION_EVALUATOR_CLASSES,
     FEWSHOT_EVALUATOR_CLASSES,
     CODE_EVALUATOR_CLASSES,
+    REINFORCEMENT_MODEL_CLASSES,
     OTHER_EVALUATOR_CLASSES
 ]

evaluators/reinforcement_learning_evaluator.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# @Time    : 2023/5/6 8:09 p.m.
# @Author  : JianingWang
# @File    : reinforcement_learning_evaluator.py

import json
import os.path
import math
import torch
import numpy as np
from tqdm import tqdm
from typing import Dict, Union, Any, Optional, Callable, List, Tuple, Iterator
import datasets
from datasets import Dataset
from config import DataTrainingArguments, TrainingArguments, ModelArguments
from hugnlp_trainer import HugTrainer
from processors.ProcessorBase import DataProcessor
from evaluators.EvaluatorBase import NO_GENERATE, DO_GENERATE, Evaluator, ClassificationEvaluator, GenerationEvaluator
from metrics.classification_metric import ClassificationMetric
from tools.runner_utils.log_util import logging
from tools.computations.softmax import softmax
from tools.model_utils.calibrate import CausalCLSCalibrator

logger = logging.getLogger(__name__)


"""
Evaluator for the pair-wise reward model.
"""
class PairwiseRewardEvaluator(ClassificationEvaluator):

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        training_args: TrainingArguments,
        processor: DataProcessor,
        model: torch.nn.Module,
        trainer: Optional[HugTrainer] = None,
        eval_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
    ) -> None:
        super().__init__(model_args, data_args, training_args, processor, model, trainer, eval_dataset, test_dataset)
        self.paradigm = NO_GENERATE

    def default_compute_metrics(self, eval_predictions):
        """
        Default metrics calculation for the current task.
        Note:
        - If the task processor has a 'compute_metrics' attribute, this function is not used.
        - If this pre-built function matches your demand, you can omit defining 'compute_metrics' in your processor.
        """
        examples = self.eval_dataset
        labels = examples["label"]

        golden = {}
        # predictions: {"xx": "xxx", ...}
        predictions, _ = self.get_best_and_topk(eval_predictions[0], examples, stage="dev")
        for example in examples:
            try:
                idx = int(example["idx"])
            except:
                idx = int(example["idx"].split("-")[1])  # e.g., "dev-12" -> "12"

            golden[idx] = example["label"]

        all_metrics = {
            "eval_macro_f1": 0.,
            "eval_acc": 0.,
        }

        metric = ClassificationMetric()
        gold = {k: v for k, v in golden.items()}
        pred = {k: v for k, v in predictions.items()}
        score = metric.calc_metric(golden=gold, predictions=pred)
        acc, f1 = score["acc"], score["f1"]
        all_metrics["eval_macro_f1"] += f1
        all_metrics["eval_acc"] += acc
        return all_metrics

    def evaluate(self, test_dataset=None):
        """
        Each example has the following two sequences:
        - chosen: the better response
        - rejected: the worse response
        The reward model should assign a higher reward to the chosen sequence than to the rejected one.
        Thus, we compute the accuracy with which the reward of the chosen sequence exceeds the reward of the rejected sequence.
        """
        eval_dataset = self.eval_dataset if test_dataset is None else test_dataset
        all_chosen_values, all_rejected_values = list(), list()
        for ei, data in enumerate(tqdm(eval_dataset)):
            # chosen_input_ids, chosen_attention_mask = data["chosen_sequence"], data["chosen_attention_mask"]
            # rejected_input_ids, rejected_attention_mask = data["rejected_sequence"], data["rejected_attention_mask"]
            chosen_output = self.model(**data)
            chosen_values, rejected_values = chosen_output["chosen_values"], chosen_output["rejected_values"]
            all_chosen_values.extend(chosen_values.detach().cpu().numpy().tolist())
            all_rejected_values.extend(rejected_values.detach().cpu().numpy().tolist())

        metrics = dict()
        acc = 0.
        for chosen_value, rejected_value in zip(all_chosen_values, all_rejected_values):
            if chosen_value >= rejected_value:
                acc += 1
        metrics["acc"] = round(acc / len(all_chosen_values), 4)
        self.trainer.log_metrics("eval", metrics)
        self.trainer.save_metrics("eval", metrics)

    def predict(self):
        self.evaluate(test_dataset=self.test_dataset)

    def get_best_and_topk(self, logits, examples, topk=10, stage="dev"):
        # Not implemented yet in this commit.
        pass
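
To make the metric concrete, here is a small standalone check of the pairwise accuracy that evaluate() computes; the reward values are made up for illustration:

chosen_rewards = [1.2, 0.3, 2.0]
rejected_rewards = [0.4, 0.9, 1.5]

# A pair counts as correct when the chosen response scores at least as high
# as the rejected one, mirroring the loop in evaluate().
wins = sum(c >= r for c, r in zip(chosen_rewards, rejected_rewards))
acc = round(wins / len(chosen_rewards), 4)
print(acc)  # 0.6667 -> 2 of the 3 pairs are ranked correctly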

hugnlp_runner.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def main():
     model_args, data_args, training_args, semi_training_args = parser.parse_args_into_dataclasses()

     # Print hello world
-    if training_args.local_rank == 0:
+    if training_args.local_rank <= 0 or os.environ['LOCAL_RANK'] == "0":
         print_hello()

     training_args.output_dir = os.path.join(training_args.output_dir, list(filter(None, model_args.model_name_or_path.split("/")))[-1])
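
The new guard falls back to the LOCAL_RANK environment variable set by torch.distributed.launch when the parsed local_rank alone is not enough to identify the first process. A hedged sketch of an equivalent check that also tolerates the variable being unset (not part of this commit):

import os

def is_rank_zero(local_rank: int) -> bool:
    # local_rank is -1 for non-distributed runs and 0 for the first process;
    # LOCAL_RANK is exported by torch.distributed.launch / torchrun.
    # os.environ.get avoids a KeyError when the variable is absent.
    return local_rank <= 0 or os.environ.get("LOCAL_RANK") == "0"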

loss/rl_loss.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from typing import Optional

import torch
import torch.nn as nn


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward)**2
        surr2 = (values - reward)**2
        loss = torch.max(surr1, surr2)
        loss = loss.mean()
        return 0.5 * loss


class PPOPtxActorLoss(nn.Module):
    """
    To Do:

    PPO-ptx Actor Loss
    """

    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        self.pretrain_loss_fn = pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss


class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(chosen_reward - reject_reward)
        log_probs = torch.log(probs)
        loss = -log_probs.mean()
        return loss


class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss
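
LogSigLoss and LogExpLoss optimize the same pairwise objective: -log sigmoid(r_chosen - r_rejected) equals log(1 + exp(r_rejected - r_chosen)). A quick sanity check with made-up reward values:

import torch

chosen = torch.tensor([1.2, 0.3, 2.0])
rejected = torch.tensor([0.4, 0.9, 1.5])

log_sig = -torch.log(torch.sigmoid(chosen - rejected)).mean()
log_exp = torch.log(1 + torch.exp(rejected - chosen)).mean()
# The two formulations agree up to floating-point error.
assert torch.allclose(log_sig, log_exp)

In practice, torch.nn.functional.logsigmoid is the numerically safer way to evaluate the first form when the reward gap is large.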

models/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -73,7 +73,9 @@

 from models.reinforcement_learning.actor import CausalActor
 from models.reinforcement_learning.critic import AutoModelCritic
-from models.reinforcement_learning.reward_model import AutoModelReward
+from models.reinforcement_learning.reward_model import (
+    RobertaForReward, GPT2ForReward
+)

 # Models for pre-training
 PRETRAIN_MODEL_CLASSES = {
@@ -199,7 +201,13 @@
 REINFORCEMENT_MODEL_CLASSES = {
     "causal_actor": CausalActor,
     "auto_critic": AutoModelCritic,
-    "auto_reward": AutoModelReward,
+    "rl_reward": {
+        "roberta": RobertaForReward,
+        "gpt2": GPT2ForReward,
+        "gpt-neo": None,
+        "opt": None,
+        "llama": None,
+    }
 }

 # task_type maps to the corresponding model type
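
The nested "rl_reward" entry keys the reward model first by task_type and then by model_type, matching the --task_type=rl_reward --model_type=gpt2 flags in the training script above. The dispatch code itself is outside this diff; a presumed lookup would be:

# Presumed dispatch, mirroring the flags used in the training script;
# the actual resolution logic lives outside this commit.
task_type, model_type = "rl_reward", "gpt2"
reward_model_cls = REINFORCEMENT_MODEL_CLASSES[task_type][model_type]  # -> GPT2ForReward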
