Add score scaling/normalization/clipping #560

Merged 9 commits on Aug 10, 2023. Changes shown from 3 commits.
18 changes: 18 additions & 0 deletions docs/source/customization.mdx
@@ -195,3 +195,21 @@ Note that using `python -m torch.distributed.launch --nproc_per_node=1 reward_su
ValueError: Some specified arguments are not used by the HfArgumentParser: ['--local-rank=0']
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 194889) of binary: /home/ubuntu/miniconda3/envs/trl/bin/python
```

## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
```python
from trl import PPOConfig

ppo_config = {
    "use_score_scaling": True,
    "use_score_norm": True,
    "score_clip": 0.5,
}
config = PPOConfig(**ppo_config)
```

To run `sentiment_tuning.py`, you can use the following command:
```
python examples/scripts/sentiment_tuning.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
```
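
For intuition, the following standalone sketch mirrors the transformation this PR adds to `PPOTrainer.step` (the helper name is hypothetical, and plain per-batch statistics stand in for the trainer's `RunningMoments` running estimates):

```python
import torch

def scale_and_clip_scores(scores, use_score_scaling, use_score_norm, score_clip,
                          running_mean, running_std):
    # Hypothetical helper mirroring the logic this PR adds to PPOTrainer.step.
    if use_score_scaling:
        if use_score_norm:
            # Normalization: subtract the running mean and divide by the running std.
            scores = (scores - running_mean) / running_std
        else:
            # Scaling only: divide by the running std.
            scores = scores / running_std
    if score_clip is not None:
        # Clipping: bound the (scaled) scores to [-score_clip, score_clip].
        scores = torch.clip(scores, -score_clip, score_clip)
    return scores

rewards = torch.tensor([2.3, -0.7, 4.1, 0.2])
print(scale_and_clip_scores(rewards, True, True, 0.5, rewards.mean(), rewards.std()))
```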
7 changes: 6 additions & 1 deletion examples/scripts/sentiment_tuning.py
@@ -45,7 +45,6 @@ class ScriptArguments:
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})
Contributor: This field seems to have been removed by mistake?

Contributor Author:
Hi Younes,

You will find that target_kl already exists on L57 with a much smaller value.

I dug deeper and found that PPOConfig has two configs, target and target_kl, where target has a default value of 6. So I assume the first duplicate target_kl config here was meant to be target. However, target is NOT used to populate PPOConfig at L64, so I just removed it.

Regards,

Felix
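
For context on the two settings discussed here, a rough sketch of their different roles, based on the description above and the `PPOConfig` field names (illustrative only, not the exact trl control flow):

```python
# Illustrative only: the two settings discussed above.
# `target` (default 6) is the KL value the adaptive KL controller steers toward,
# while `target_kl` (default 0.1 in this script) is an early-stopping threshold.

def adapt_kl_coef(kl_coef: float, observed_kl: float, target: float,
                  n_steps: int, horizon: int = 10000) -> float:
    # Nudge the KL penalty coefficient so the observed KL drifts toward `target`.
    proportional_error = max(min(observed_kl / target - 1.0, 0.2), -0.2)
    return kl_coef * (1.0 + proportional_error * n_steps / horizon)


def should_stop_early(observed_kl: float, target_kl: float) -> bool:
    # Abandon further PPO epochs for the batch once KL exceeds the threshold.
    return observed_kl > target_kl
```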

Contributor: Great point, thank you!

Member: I think this is actually a bug from here: 1620da3. We overloaded the target_kl term; we should rename it!

Collaborator:
@lvwerra, as much as I love introducing bugs into trl, I think this time it was @younesbelkada, in the Big refactor of examples and documentation (#509).

I agree with renaming it to early_stop_kl, or something similar.

use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"})
kl_penalty: Optional[str] = field(
@@ -56,6 +55,9 @@ class ScriptArguments:
)
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(default=False, metadata={"help": "Use score normalization"})
Member: Maybe we should clarify that this only works if use_score_scaling is also True, otherwise it's actually ignored. Or we change the logic a bit in general.
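
To make that interplay concrete, a small hedged illustration (flag values picked for the example; the fields are the ones added by this PR):

```python
from trl import PPOConfig

# use_score_norm only takes effect when use_score_scaling is also enabled;
# on its own, the normalization flag is ignored.
ignored = PPOConfig(use_score_norm=True)                         # scores left untouched
active = PPOConfig(use_score_scaling=True, use_score_norm=True)  # mean/std normalization
```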

Contributor Author: Done.

score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})


parser = HfArgumentParser(ScriptArguments)
@@ -72,6 +74,9 @@ class ScriptArguments:
target_kl=script_args.target_kl,
kl_penalty=script_args.kl_penalty,
seed=script_args.seed,
use_score_scaling=script_args.use_score_scaling,
use_score_norm=script_args.use_score_norm,
score_clip=script_args.score_clip,
)


5 changes: 4 additions & 1 deletion trl/models/modeling_base.py
@@ -441,7 +441,10 @@ def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="r
num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])

self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(self._get_current_device())
self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=self._get_current_device(),
dtype=self.pretrained_model.dtype,
)
self.score.load_state_dict(score_dict)

# load the adapter to the model
8 changes: 7 additions & 1 deletion trl/trainer/__init__.py
@@ -16,7 +16,13 @@

# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import AdaptiveKLController, FixedKLController, ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
)

# isort: on

3 changes: 3 additions & 0 deletions trl/trainer/ppo_config.py
@@ -159,6 +159,9 @@ class PPOConfig(object):
ratio_threshold: Optional[float] = field(
default=10.0, metadata={"help": "Skip mini-batches with high PPO ratios that can cause loss spikes"}
)
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(default=False, metadata={"help": "Use score normalization"})
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})

def __post_init__(self):
if self.forward_batch_size is not None:
28 changes: 19 additions & 9 deletions trl/trainer/ppo_trainer.py
@@ -50,7 +50,7 @@
)
from ..import_utils import is_torch_greater_2_0
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments


MODEL_CARD_TEMPLATE = """---
@@ -338,6 +338,8 @@ def __init__(

PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache

self.running = RunningMoments(self.accelerator)

def _filter_kwargs(self, kwargs, target_func):
"""
filter the keyword arguments that are supported by the target function.
@@ -382,7 +384,7 @@ def _set_signature_columns_if_needed(self):
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# label => sentiment | we need query and response for logging purpose
self._signature_columns += list(set(["label", "query", "response"]))
self._signature_columns += ["label", "query", "response"]

# Adapted from transformers.Trainer._remove_unused_columns
def _remove_unused_columns(self, dataset: "Dataset"):
@@ -582,11 +584,23 @@ def step(
bs = self.config.batch_size

queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)
scores = torch.tensor(scores)
if self.config.use_score_scaling:
# Score scaling
scores_mean, scores_std = self.running.update(scores)
if self.config.use_score_norm:
scores = (scores - self.running.mean) / self.running.std
else:
scores /= self.running.std

if self.config.score_clip is not None:
# Score clipping
scores = torch.clip(scores, -self.config.score_clip, self.config.score_clip)

# if we want to push best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = torch.tensor(scores).mean()
curr_mean_reward = scores.mean()
# if the best reward ever seen
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
@@ -1148,8 +1162,8 @@ def record_step_stats(self, kl_coef: float, **data):
mean_non_score_reward = masked_mean(
data["non_score_reward"], mask
) # non_score_reward is size `batch_size`, `response_length`
mean_scores = torch.stack(data["scores"]).mean() # scores is size `batch_size`
std_scores = torch.stack(data["scores"]).std()
mean_scores = data["scores"].mean() # scores is size `batch_size`
std_scores = data["scores"].std()

if mean_kl.item() < -1.0:
# warn users
@@ -1243,10 +1257,6 @@ def log_stats(
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()

logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()

if self.config.log_with == "tensorboard":
# update the current step
self.current_step += 1
61 changes: 60 additions & 1 deletion trl/trainer/utils.py
@@ -15,7 +15,7 @@
import random
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -461,6 +465,65 @@ def on_save(self, args, state, control, **kwargs):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class RunningMoments:
def __init__(self, accelerator):
"""
Calculates the running mean and standard deviation of a data stream. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
"""
self.mean = 0
self.std = 1
self.var = 1
self.count = 1e-24
self.accelerator = accelerator

@torch.no_grad()
def update(self, xs: torch.Tensor) -> Tuple[float, float]:
"""
Updates running moments from batch's moments computed across ranks
"""
if self.accelerator.use_distributed:
xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
else:
xs_count = xs.numel()
xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
xs_mean, xs_var = xs_mean.float(), xs_var.float()

delta = xs_mean - self.mean
tot_count = self.count + xs_count

new_sum = xs_var * xs_count
# correct old_sum deviation accounting for the new mean
old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
tot_sum = old_sum + new_sum

self.mean += delta * xs_count / tot_count
self.var = tot_sum / tot_count
self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
self.count = tot_count

return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()


@torch.no_grad()
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
"""
Computes element-wise mean and variance of the tensor across processes. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
"""
xs = xs.to(accelerator.device)
sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
sum_and_count = accelerator.reduce(sum_and_count)
global_sum, count = sum_and_count
global_mean = global_sum / count

sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
sum_var = accelerator.reduce(sum_var)
global_var = sum_var / count

return global_mean.to(device), global_var.to(device), count.to(device)


def compute_accuracy(eval_pred) -> Dict[str, float]:
predictions, labels = eval_pred
# Here, predictions is rewards_chosen and rewards_rejected.
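
A quick usage sketch of the new `RunningMoments` helper added above (assumes a single-process `Accelerator`; the reward stream is synthetic):

```python
import torch
from accelerate import Accelerator
from trl.trainer import RunningMoments  # re-exported by this PR

running = RunningMoments(Accelerator())
for _ in range(3):
    batch = torch.randn(64) * 2.0 + 1.0            # rewards with mean ~1, std ~2
    batch_mean, batch_std = running.update(batch)  # per-batch moments as floats

print(running.mean, running.std)  # running estimates drift toward ~1.0 and ~2.0
```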