Add score scaling/normalization/clipping (#560)
* Add reward/score scaling/normalization/clipping

* Run pre-commit to fix styles and remove some dupe code

* Make sure score module and pretrained_model have the same dtype

* Add multi_adapter_rl_v2.py

* Add log_with

* Add more verbose help message for use_score_norm

* Fix score clipping for float16

* Minor fix
zfang authored Aug 10, 2023
1 parent 25fd6f2 commit 3b2c820
Showing 8 changed files with 275 additions and 16 deletions.
19 changes: 19 additions & 0 deletions docs/source/customization.mdx
@@ -179,3 +179,22 @@ if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
else:
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```


## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (i.e., reward) scaling, normalization, and clipping to improve training stability, configured via `PPOConfig`:
```python
from trl import PPOConfig

ppo_config = {
    "use_score_scaling": True,
    "use_score_norm": True,
    "score_clip": 0.5,
}
config = PPOConfig(**ppo_config)
```

To run `sentiment_tuning.py` with these options enabled, you can use the following command:
```
python examples/scripts/sentiment_tuning.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
```
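
For intuition: score scaling divides each batch of rewards by a running standard deviation, normalization additionally subtracts the running mean, and clipping bounds the result to `[-score_clip, score_clip]` (the clipping is done in float32, per the float16 fix in this commit). Below is a minimal sketch of that transformation; `transform_scores`, `running_mean`, and `running_std` are illustrative names standing in for the trainer's internal running statistics, not TRL API:

```python
import torch


def transform_scores(scores: torch.Tensor, running_mean: float, running_std: float,
                     use_score_scaling=True, use_score_norm=True, score_clip=0.5):
    # Scale by the running std; with use_score_norm, also center on the running mean.
    if use_score_scaling:
        if use_score_norm:
            scores = (scores - running_mean) / running_std
        else:
            scores = scores / running_std
    # Clip in float32, then restore the original dtype (e.g. float16).
    if score_clip is not None:
        dtype = scores.dtype
        scores = torch.clip(scores.float(), -score_clip, score_clip).to(dtype)
    return scores
```

The actual implementation of this logic lives in the `ppo_trainer.py` change later in this commit.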
150 changes: 150 additions & 0 deletions examples/scripts/multi_adapter_rl_v2.py
@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler


input_min_text_length = 6
input_max_text_length = 12


@dataclass
class ScriptArguments:
"""
The name of the causal LM model we wish to fine-tune with PPO
"""

model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
rm_adapter: Optional[str] = field(
default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"}
)
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


def create_and_prepare_dataset(tokenizer):
dataset = load_dataset(script_args.dataset_name, split="train[:1%]")

input_size = LengthSampler(input_min_text_length, input_max_text_length)

def tokenize(example):
text_size = input_size()
example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size]
example["query"] = tokenizer.decode(example["input_ids"])
return example

dataset = dataset.map(tokenize, batched=False)
dataset.set_format("torch")
return dataset


lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
nf4_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
device_map={"": 0},
peft_config=lora_config,
quantization_config=nf4_config,
reward_adapter=script_args.rm_adapter,
use_safetensors=script_args.use_safetensors,
)
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)

tokenizer.pad_token = tokenizer.eos_token

dataset = create_and_prepare_dataset(tokenizer)


def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])


config = PPOConfig(
model_name=script_args.model_name,
log_with=script_args.log_with,
learning_rate=1e-5,
batch_size=8,
mini_batch_size=2,
gradient_accumulation_steps=2,
optimize_cuda_cache=True,
seed=script_args.seed,
use_score_scaling=script_args.use_score_scaling,
use_score_norm=script_args.use_score_norm,
score_clip=script_args.score_clip,
)

ppo_trainer = PPOTrainer(
config,
model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=collator,
)

generation_kwargs = {
"top_k": 0.0,
"top_p": 0.9,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]

response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=False,
**generation_kwargs,
)
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

# Compute reward score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))]  # reward score at the last token (label index 1)

# Run PPO step
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)
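
The new script accepts the same score-related flags as `sentiment_tuning.py`; assuming the defaults above, it can be launched with something like `python examples/scripts/multi_adapter_rl_v2.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5`.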
14 changes: 10 additions & 4 deletions examples/scripts/sentiment_tuning.py
@@ -45,7 +45,6 @@ class ScriptArguments:
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"})
kl_penalty: Optional[str] = field(
@@ -56,6 +55,11 @@
)
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})


parser = HfArgumentParser(ScriptArguments)
@@ -72,8 +76,13 @@ class ScriptArguments:
target_kl=script_args.target_kl,
kl_penalty=script_args.kl_penalty,
seed=script_args.seed,
use_score_scaling=script_args.use_score_scaling,
use_score_norm=script_args.use_score_norm,
score_clip=script_args.score_clip,
)

# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
@@ -127,9 +136,6 @@ def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])


# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# Now let's build the model, the reference model, and the tokenizer.
if not script_args.use_peft:
ref_model = trl_model_class.from_pretrained(config.model_name)
5 changes: 4 additions & 1 deletion trl/models/modeling_base.py
@@ -461,7 +461,10 @@ def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="r
num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(self._get_current_device())
self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=self._get_current_device(),
dtype=self.pretrained_model.dtype,
)
self.score.load_state_dict(score_dict)

# load the adapter to the model
8 changes: 7 additions & 1 deletion trl/trainer/__init__.py
@@ -16,7 +16,13 @@

# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import AdaptiveKLController, FixedKLController, ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
)

# isort: on

5 changes: 5 additions & 0 deletions trl/trainer/ppo_config.py
@@ -162,6 +162,11 @@ class PPOConfig(object):
ratio_threshold: Optional[float] = field(
default=10.0, metadata={"help": "Skip mini-batches with high PPO ratios that can cause loss spikes"}
)
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})

def __post_init__(self):
if self.forward_batch_size is not None:
29 changes: 20 additions & 9 deletions trl/trainer/ppo_trainer.py
@@ -53,7 +53,7 @@
)
from ..import_utils import is_torch_greater_2_0
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments


MODEL_CARD_TEMPLATE = """---
@@ -344,6 +344,8 @@ def __init__(

PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache

self.running = RunningMoments(self.accelerator)

def _filter_kwargs(self, kwargs, target_func):
"""
filter the keyword arguments that are supported by the target function.
@@ -388,7 +390,7 @@ def _set_signature_columns_if_needed(self):
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# label => sentiment | we need query and response for logging purpose
self._signature_columns += list(set(["label", "query", "response"]))
self._signature_columns += ["label", "query", "response"]

# Adapted from transformers.Trainer._remove_unused_columns
def _remove_unused_columns(self, dataset: "Dataset"):
@@ -588,11 +590,24 @@ def step(
bs = self.config.batch_size

queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)
scores = torch.tensor(scores)
if self.config.use_score_scaling:
# Score scaling
scores_mean, scores_std = self.running.update(scores)
if self.config.use_score_norm:
scores = (scores - self.running.mean) / self.running.std
else:
scores /= self.running.std

if self.config.score_clip is not None:
# Score clipping
scores_dtype = scores.dtype
scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)

# if we want to push best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = torch.tensor(scores).mean()
curr_mean_reward = scores.mean()
# if the best reward ever seen
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
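
The scaling above leans on the new `RunningMoments` helper, which, judging by its use here, tracks a running mean and standard deviation of the scores across batches (its constructor also receives the `accelerator`, presumably so the statistics can be synchronized across processes; that part is omitted below). A rough, self-contained sketch of such a tracker, not the actual `RunningMoments` implementation, using the Chan et al. parallel update for combining batch statistics:

```python
import torch


class RunningMomentsSketch:
    """Running mean/std of reward scores across batches (illustrative sketch only)."""

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0
        self.var = 1.0
        self.count = 1e-24  # avoids division by zero before the first real batch

    @torch.no_grad()
    def update(self, xs: torch.Tensor):
        """Fold a batch of scores into the running stats; return the batch's own mean and std."""
        xs = xs.float()
        xs_count = xs.numel()
        xs_mean = xs.mean().item()
        xs_var = xs.var(unbiased=False).item()

        # Chan et al. parallel algorithm: combine the running stats with the new batch
        delta = xs_mean - self.mean
        total = self.count + xs_count
        self.mean += delta * xs_count / total
        m2 = self.var * self.count + xs_var * xs_count + delta**2 * self.count * xs_count / total
        self.var = m2 / total
        self.std = (self.var * total / max(total - 1, 1)) ** 0.5  # Bessel-corrected std
        self.count = total

        return xs_mean, xs_var**0.5
```

With that in place, `scores /= self.running.std` rescales rewards toward unit variance, the `use_score_norm` branch additionally centers them on the running mean, and `score_clip` bounds the result, with the clipping performed in float32 and cast back to the original dtype (the float16 fix mentioned in the commit message).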
@@ -1186,8 +1201,8 @@ def record_step_stats(self, kl_coef: float, **data):
mean_non_score_reward = masked_mean(
data["non_score_reward"], mask
) # non_score_reward is size `batch_size`, `response_length`
mean_scores = torch.stack(data["scores"]).mean() # scores is size `batch_size`
std_scores = torch.stack(data["scores"]).std()
mean_scores = data["scores"].mean() # scores is size `batch_size`
std_scores = data["scores"].std()

if mean_kl.item() < -1.0:
# warn users
@@ -1281,10 +1296,6 @@ def log_stats(
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()

logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()

if self.config.log_with == "tensorboard":
# update the current step
self.current_step += 1