@@ -3,6 +3,7 @@
 from datetime import timedelta
 from typing import cast
 
+import pandas as pd
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -80,7 +81,7 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableDataset
         cfg.model,
         device_map=device_map,
         quantization_config=quantization_config,
-        torch_dtype=dtype,
+        dtype=dtype,
         revision=cfg.revision,
     )
     target_modules = None
@@ -91,7 +92,7 @@ def worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset | IterableDataset
         peft_config.base_model_name_or_path,  # type: ignore
         device_map=device_map,
         quantization_config=quantization_config,
-        torch_dtype=dtype,
+        dtype=dtype,
         revision=cfg.revision,
     )
 
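
These two hunks track the rename of the `torch_dtype` keyword to `dtype` in `from_pretrained`, matching the deprecation of `torch_dtype` in recent transformers releases (4.56+). A minimal sketch of the new spelling; the model class and checkpoint here are illustrative, not taken from this repo:

    import torch
    from transformers import AutoModelForCausalLM

    # `dtype` replaces the deprecated `torch_dtype` keyword.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        device_map="auto",
        dtype=torch.bfloat16,
    )
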
@@ -190,6 +191,20 @@ def dist_worker(rank: int, world_size: int, cfg: IndexConfig, ds: Dataset):
     dist.destroy_process_group()
 
 
+def estimate_advantage(ds: Dataset, cfg: IndexConfig):
+    """Group rollouts by prompt and estimate advantages."""
+    assert isinstance(ds, Dataset), "Dataset required for advantage estimation"
+
+    df = ds.select_columns([cfg.data.prompt_column, cfg.data.reward_column]).to_pandas()
+    df = assert_type(pd.DataFrame, df)
+
+    advantages = df[cfg.data.reward_column] - df.groupby(cfg.data.prompt_column)[
+        cfg.data.reward_column
+    ].transform("mean")
+
+    return advantages.tolist()
+
+
 def build_gradient_dataset(cfg: IndexConfig):
     # In many cases the token_batch_size may be smaller than the max length allowed by
     # the model. If cfg.data.truncation is True, we use the tokenizer to truncate
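
The new `estimate_advantage` helper centers each rollout's reward on the mean reward of all rollouts that share the same prompt, i.e. a group-mean baseline in the spirit of GRPO (without the standard-deviation normalization). A toy check of the same pandas computation, with hypothetical column names standing in for `cfg.data.prompt_column` and `cfg.data.reward_column`:

    import pandas as pd

    df = pd.DataFrame({
        "prompt": ["p1", "p1", "p2", "p2"],
        "reward": [1.0, 0.0, 2.0, 4.0],
    })
    # Subtract the per-prompt mean reward from each rollout's reward.
    adv = df["reward"] - df.groupby("prompt")["reward"].transform("mean")
    print(adv.tolist())  # [0.5, -0.5, -1.0, 1.0]
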
@@ -206,6 +221,13 @@ def build_gradient_dataset(cfg: IndexConfig):
         fn_kwargs=dict(args=cfg.data, tokenizer=tokenizer),
         remove_columns=remove_columns,
     )
+    if cfg.data.reward_column:
+        ds = ds.add_column(
+            "advantage",
+            estimate_advantage(ds, cfg),
+            new_fingerprint="advantage",  # type: ignore
+        )
+
     world_size = torch.cuda.device_count()
     if world_size <= 1:
         # Run the worker directly if no distributed training is needed. This is great
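
With `cfg.data.reward_column` set, the mapped dataset gains an `advantage` column aligned with row order. A self-contained sketch of that `add_column` step using the datasets library, with toy rows in place of the real prompts and rewards:

    from datasets import Dataset

    ds = Dataset.from_dict({"prompt": ["p1", "p1"], "reward": [1.0, 0.0]})
    advantages = [0.5, -0.5]  # what estimate_advantage would return here
    ds = ds.add_column("advantage", advantages, new_fingerprint="advantage")
    print(ds[0])  # {'prompt': 'p1', 'reward': 1.0, 'advantage': 0.5}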