
Commit 43993eb

Support GRPO advantage estimate weighting

Parent: fe4b13b

3 files changed: +35 −2 lines

bergson/build.py (24 additions, 2 deletions)
@@ -3,6 +3,7 @@
 from datetime import timedelta
 from typing import cast
 
+import pandas as pd
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -80,7 +81,7 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableDataset):
         cfg.model,
         device_map=device_map,
         quantization_config=quantization_config,
-        torch_dtype=dtype,
+        dtype=dtype,
         revision=cfg.revision,
     )
     target_modules = None
@@ -91,7 +92,7 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableDataset):
             peft_config.base_model_name_or_path,  # type: ignore
             device_map=device_map,
             quantization_config=quantization_config,
-            torch_dtype=dtype,
+            dtype=dtype,
             revision=cfg.revision,
         )
 
@@ -190,6 +191,20 @@ def dist_worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset):
     dist.destroy_process_group()
 
 
+def estimate_advantage(ds: Dataset, cfg: IndexConfig):
+    """Group rollouts by prompt and estimate advantages."""
+    assert isinstance(ds, Dataset), "Dataset required for advantage estimation"
+
+    df = ds.select_columns([cfg.data.prompt_column, cfg.data.reward_column]).to_pandas()
+    df = assert_type(pd.DataFrame, df)
+
+    advantages = df[cfg.data.reward_column] - df.groupby(cfg.data.prompt_column)[
+        cfg.data.reward_column
+    ].transform("mean")
+
+    return advantages.tolist()
+
+
 def build_gradient_dataset(cfg: IndexConfig):
     # In many cases the token_batch_size may be smaller than the max length allowed by
     # the model. If cfg.data.truncation is True, we use the tokenizer to truncate
@@ -206,6 +221,13 @@ def build_gradient_dataset(cfg: IndexConfig):
         fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),
         remove_columns=remove_columns,
     )
+    if cfg.data.reward_column:
+        ds = ds.add_column(
+            "advantage",
+            estimate_advantage(ds, cfg),
+            new_fingerprint="advantage",  # type: ignore
+        )
+
     world_size = torch.cuda.device_count()
     if world_size <= 1:
         # Run the worker directly if no distributed training is needed. This is great

bergson/collection.py (6 additions, 0 deletions)
@@ -107,6 +107,9 @@ def callback(name: str, g: torch.Tensor):
             # Compute average KL across all unmasked tokens
             kls = torch.sum(ft_lps.exp() * (ft_lps - ref_lps), dim=-1)
             losses = torch.sum(kls * masks, dim=-1) / denoms
+            if "advantage" in batch:
+                losses *= torch.tensor(batch["advantage"], device=losses.device)
+
             losses.mean().backward()
         else:
             with collector:
@@ -118,6 +121,9 @@ def callback(name: str, g: torch.Tensor):
                 reduction="none",
             ).reshape_as(y[:, 1:])
             losses = losses.sum(1) / denoms
+            if "advantage" in batch:
+                losses *= torch.tensor(batch["advantage"], device=losses.device)
+
             losses.mean().backward()
 
     # Weirdly you need to explicitly synchronize here in order to make sure that
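In both branches the per-sequence losses are multiplied by their advantages before the backward pass, so the gradients the collector sees are policy-gradient-weighted rather than plain likelihood gradients. A minimal sketch of the weighting, using a dummy differentiable stand-in for the per-sequence losses:

import torch

# Dummy differentiable computation standing in for per-sequence losses.
params = torch.randn(3, requires_grad=True)
losses = params.pow(2)

batch = {"advantage": [0.5, -0.5, 0.0]}  # as produced by estimate_advantage

# Scale each sequence's loss by its advantage, as in the hunks above.
losses = losses * torch.tensor(batch["advantage"], device=losses.device)
losses.mean().backward()

# Positive advantages leave the sign of each sequence's loss gradient intact,
# negative advantages flip it, and zero advantages contribute nothing.
print(params.grad)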

bergson/data.py (5 additions, 0 deletions)
@@ -41,6 +41,11 @@ class DataConfig:
     conversation_column: str = ""
    """Optional column in the dataset that contains the conversation."""
 
+    reward_column: str = ""
+    """Optional column in the dataset that contains the rewards.
+    When specified, gradients are calculated using the policy
+    gradient loss from Dr. GRPO. https://arxiv.org/abs/2503.20783"""
+
     truncation: bool = False
     """Whether to truncate long documents to fit the token budget."""
 
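Turning the feature on is a config change: a non-empty reward_column routes build_gradient_dataset through estimate_advantage, and collection.py picks up the resulting "advantage" column. A hypothetical sketch, assuming DataConfig's remaining fields can be left at their defaults and that prompt_column is the existing field read by estimate_advantage:

from bergson.data import DataConfig

# "prompt" and "reward" are illustrative column names for your dataset.
cfg = DataConfig(
    prompt_column="prompt",   # groups rollouts that share a prompt
    reward_column="reward",   # non-empty value enables advantage weighting
)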
0 commit comments