Add score scaling/normalization/clipping #560

Merged (9 commits, Aug 10, 2023)

19 changes: 19 additions & 0 deletions docs/source/customization.mdx
@@ -179,3 +179,22 @@ if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
else:
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```


## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (i.e., reward) scaling, normalization, and clipping to improve training stability via `PPOConfig`:
```python
from trl import PPOConfig

ppo_config = {
    "use_score_scaling": True,
    "use_score_norm": True,
    "score_clip": 0.5,
}
config = PPOConfig(**ppo_config)
```

To run `sentiment_tuning.py`, you can use the following command:
```bash
python examples/scripts/sentiment_tuning.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
```
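
For intuition, the three options roughly amount to the following transformation of the raw reward tensor inside `PPOTrainer.step` (an illustrative sketch: the trainer keeps running mean/std statistics across batches via `RunningMoments`, whereas this snippet just uses per-batch statistics):

```python
import torch

# Raw rewards for one PPO batch (made-up values).
scores = torch.tensor([0.2, 1.5, -0.7, 3.0])

# Stand-ins for the running statistics the trainer tracks across batches.
running_mean, running_std = scores.mean(), scores.std()

use_score_scaling, use_score_norm, score_clip = True, True, 0.5

if use_score_scaling:
    if use_score_norm:
        scores = (scores - running_mean) / running_std  # scale and center
    else:
        scores = scores / running_std                   # scale only
if score_clip is not None:
    scores = torch.clip(scores, -score_clip, score_clip)  # bound the rewards
```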
150 changes: 150 additions & 0 deletions examples/scripts/multi_adapter_rl_v2.py
@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler


input_min_text_length = 6
input_max_text_length = 12


@dataclass
class ScriptArguments:
"""
The name of the Casual LM model we wish to fine with PPO
"""

model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
rm_adapter: Optional[str] = field(
default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"}
)
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


def create_and_prepare_dataset(tokenizer):
    dataset = load_dataset(script_args.dataset_name, split="train[:1%]")

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(example):
        text_size = input_size()
        example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size]
        example["query"] = tokenizer.decode(example["input_ids"])
        return example

    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format("torch")
    return dataset


lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    script_args.model_name,
    device_map={"": 0},
    peft_config=lora_config,
    quantization_config=nf4_config,
    reward_adapter=script_args.rm_adapter,
    use_safetensors=script_args.use_safetensors,
)
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)

tokenizer.pad_token = tokenizer.eos_token

dataset = create_and_prepare_dataset(tokenizer)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])


config = PPOConfig(
    model_name=script_args.model_name,
    log_with=script_args.log_with,
    learning_rate=1e-5,
    batch_size=8,
    mini_batch_size=2,
    gradient_accumulation_steps=2,
    optimize_cuda_cache=True,
    seed=script_args.seed,
    use_score_scaling=script_args.use_score_scaling,
    use_score_norm=script_args.use_score_norm,
    score_clip=script_args.score_clip,
)

ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)

generation_kwargs = {
    "top_k": 0.0,
    "top_p": 0.9,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
}

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]

    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute reward score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
    raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
    rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))]  # take last token

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
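
The same flags are exposed through this script's `ScriptArguments`, so they can also be toggled from the command line, for example (illustrative invocation):

```bash
python examples/scripts/multi_adapter_rl_v2.py --use_score_scaling --use_score_norm --score_clip 0.5
```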
14 changes: 10 additions & 4 deletions examples/scripts/sentiment_tuning.py
@@ -45,7 +45,6 @@ class ScriptArguments:
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
    target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})

Contributor: This field seems to have been removed by mistake?

Contributor Author (Felix): Hi Younes, you will find that `target_kl` already exists on L57 with a much smaller value. I dug deeper and found that `PPOConfig` has two configs, `target` and `target_kl`, where `target` has a default value of 6. So I assume the first duplicate `target_kl` config here was meant to be `target`. However, `target` is NOT used to populate `PPOConfig` at L64, so I just removed it.

Contributor: Great point, thank you!

Member: I think this is actually a bug from this commit: 1620da3. We overloaded the `target_kl` term; we should rename it!

Collaborator: @lvwerra, as much as I love introducing bugs into trl, I think this time it was @younesbelkada, in the big refactor of examples and documentation (#509). I agree to rename it to `early_stop_kl`, or something similar.
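
To summarize the two similarly named knobs being discussed, a minimal sketch (the field names are real `PPOConfig` fields; `target=6` is the default cited above, the other values are purely illustrative):

```python
from trl import PPOConfig

config = PPOConfig(
    # `target` feeds the adaptive KL controller (default 6, per the discussion above).
    adap_kl_ctrl=True,
    target=6,
    # `target_kl` is only consulted for early stopping of a PPO step.
    early_stopping=True,
    target_kl=0.1,
)
```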

    use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
    use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"})
    kl_penalty: Optional[str] = field(
@@ -56,6 +55,11 @@
    )
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
    seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
    use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
    use_score_norm: Optional[bool] = field(
        default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
    )
    score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})


parser = HfArgumentParser(ScriptArguments)
@@ -72,8 +76,13 @@ class ScriptArguments:
    target_kl=script_args.target_kl,
    kl_penalty=script_args.kl_penalty,
    seed=script_args.seed,
    use_score_scaling=script_args.use_score_scaling,
    use_score_norm=script_args.use_score_norm,
    score_clip=script_args.score_clip,
)

# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
@@ -127,9 +136,6 @@ def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])


# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# Now let's build the model, the reference model, and the tokenizer.
if not script_args.use_peft:
    ref_model = trl_model_class.from_pretrained(config.model_name)
5 changes: 4 additions & 1 deletion trl/models/modeling_base.py
@@ -461,7 +461,10 @@ def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="r
        num_labels, hidden_dim = score_dict["weight"].shape
        has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

        self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(self._get_current_device())
        self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
            device=self._get_current_device(),
            dtype=self.pretrained_model.dtype,
        )
        self.score.load_state_dict(score_dict)

        # load the adapter to the model
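
The added `dtype` argument keeps the reward head in the same precision as the (possibly bfloat16, 4-bit quantized) base model. A hypothetical illustration of the mismatch this avoids:

```python
import torch
import torch.nn as nn

hidden = torch.randn(1, 4096, dtype=torch.bfloat16)  # hidden states from a bf16 base model
score = nn.Linear(4096, 2)                           # nn.Linear defaults to float32
# score(hidden)  # would raise a dtype-mismatch RuntimeError
score = score.to(dtype=torch.bfloat16)               # cast the head to the base model's dtype
print(score(hidden).dtype)                           # torch.bfloat16
```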
8 changes: 7 additions & 1 deletion trl/trainer/__init__.py
@@ -16,7 +16,13 @@

# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import AdaptiveKLController, FixedKLController, ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from .utils import (
    AdaptiveKLController,
    FixedKLController,
    ConstantLengthDataset,
    DataCollatorForCompletionOnlyLM,
    RunningMoments,
)

# isort: on

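`RunningMoments` itself lives in `trl/trainer/utils.py` and is not shown in this diff. As a rough sketch of what such a cross-batch mean/std tracker does (an assumption about the internals; the real class also synchronizes statistics across processes via the accelerator):

```python
import torch


class RunningMomentsSketch:
    """Track a running mean and standard deviation over successive reward batches."""

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0
        self.var = 1.0
        self.count = 1e-24  # avoids division by zero before the first update

    def update(self, xs: torch.Tensor):
        """Fold a batch into the running stats; return the batch's own mean and std."""
        batch_count = xs.numel()
        batch_mean = xs.mean().item()
        batch_var = xs.var(unbiased=False).item()

        # Chan et al. parallel combination of (mean, variance, count) summaries.
        delta = batch_mean - self.mean
        total = self.count + batch_count
        self.mean += delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count + delta**2 * self.count * batch_count / total
        self.var = m2 / total
        self.std = (self.var * total / (total - 1)) ** 0.5  # unbiased estimate
        self.count = total

        return batch_mean, (batch_var * batch_count / max(batch_count - 1, 1)) ** 0.5
```
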
5 changes: 5 additions & 0 deletions trl/trainer/ppo_config.py
@@ -162,6 +162,11 @@ class PPOConfig(object):
    ratio_threshold: Optional[float] = field(
        default=10.0, metadata={"help": "Skip mini-batches with high PPO ratios that can cause loss spikes"}
    )
    use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
    use_score_norm: Optional[bool] = field(
        default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
    )
    score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})

    def __post_init__(self):
        if self.forward_batch_size is not None:
29 changes: 20 additions & 9 deletions trl/trainer/ppo_trainer.py
@@ -53,7 +53,7 @@
)
from ..import_utils import is_torch_greater_2_0
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments


MODEL_CARD_TEMPLATE = """---
@@ -344,6 +344,8 @@ def __init__(

        PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache

        self.running = RunningMoments(self.accelerator)

    def _filter_kwargs(self, kwargs, target_func):
        """
        filter the keyword arguments that are supported by the target function.
@@ -388,7 +390,7 @@ def _set_signature_columns_if_needed(self):
            signature = inspect.signature(self.model.forward)
            self._signature_columns = list(signature.parameters.keys())
            # label => sentiment | we need query and response for logging purpose
            self._signature_columns += list(set(["label", "query", "response"]))
            self._signature_columns += ["label", "query", "response"]

    # Adapted from transformers.Trainer._remove_unused_columns
    def _remove_unused_columns(self, dataset: "Dataset"):
@@ -588,11 +590,24 @@ def step(
        bs = self.config.batch_size

        queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)
        scores = torch.tensor(scores)
        if self.config.use_score_scaling:
            # Score scaling
            scores_mean, scores_std = self.running.update(scores)
            if self.config.use_score_norm:
                scores = (scores - self.running.mean) / self.running.std
            else:
                scores /= self.running.std

        if self.config.score_clip is not None:
            # Score clipping
            scores_dtype = scores.dtype
            scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)

        # if we want to push best model to the hub
        if hasattr(self, "highest_reward"):
            if self.compare_step % self.config.compare_steps == 0:
                curr_mean_reward = torch.tensor(scores).mean()
                curr_mean_reward = scores.mean()
                # if the best reward ever seen
                if curr_mean_reward > self.highest_reward:
                    self.highest_reward = curr_mean_reward
@@ -1186,8 +1201,8 @@ def record_step_stats(self, kl_coef: float, **data):
        mean_non_score_reward = masked_mean(
            data["non_score_reward"], mask
        )  # non_score_reward is size `batch_size`, `response_length`
        mean_scores = torch.stack(data["scores"]).mean()  # scores is size `batch_size`
        std_scores = torch.stack(data["scores"]).std()
        mean_scores = data["scores"].mean()  # scores is size `batch_size`
        std_scores = data["scores"].std()

        if mean_kl.item() < -1.0:
            # warn users
@@ -1281,10 +1296,6 @@ def log_stats(
            logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
            logs["env/reward_dist"] = rewards.cpu().numpy()

            logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
            logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
            logs["env/reward_dist"] = rewards.cpu().numpy()

            if self.config.log_with == "tensorboard":
                # update the current step
                self.current_step += 1