rlhf_preference_comparisons.py
"""
@author: elamrani
"""
from typing import (
Optional,
Union,
Dict
)
import numpy as np
from stable_baselines3.common import type_aliases
from imitation.util import util
from rlhf_reward_model import RewardModel
from rlhf_preference_dataset import PreferenceDataset, PreferenceDatasetNoDiscard
from rlhf_reward_trainer import RewardTrainer
from rlhf_preference_gatherer import PreferenceGatherer
from rlhf_pair_generator import PairGenerator
QUERY_SCHEDULES: Dict[str, type_aliases.Schedule] = {
"constant": lambda t: 1.0,
"hyperbolic": lambda t: 1.0 / (1.0 + t),
"inverse_quadratic": lambda t: 1.0 / (1.0 + t**2),
}
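
# A custom schedule can also be passed to PreferenceComparisons directly as a
# callable mapping training progress t in [0, 1] to a relative query weight
# (the weights are normalised in `train`). A minimal sketch, assuming an
# illustrative exponential decay that is not part of this module:
#
#     exponential_schedule = lambda t: float(np.exp(-3.0 * t))
#     # later: PreferenceComparisons(..., query_schedule=exponential_schedule)
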
class PreferenceComparisons:
    """Main interface for reward learning using preference comparisons."""

    def __init__(
        self,
        gym,
        reward_model: RewardModel,
        num_iterations: int,
        pair_generator: Optional[PairGenerator] = None,
        preference_gatherer: Optional[PreferenceGatherer] = None,
        reward_trainer: Optional[RewardTrainer] = None,
        comparison_queue_size: Optional[int] = None,
        transition_oversampling: int = 1,
        initial_comparison_frac: float = 0.1,
        initial_epoch_multiplier: float = 20.0,  # used to be 200
        query_schedule: Union[str, type_aliases.Schedule] = "hyperbolic",
        draw_freq: int = 100,
        use_wandb: bool = False,
        logger=None,
        dataset_path=None,
        device=None,
        human_fb: bool = False,
    ):
        # Init all attributes
        self.model = reward_model
        self.reward_trainer = reward_trainer
        self.preference_gatherer = preference_gatherer
        self.pair_generator = pair_generator
        self.initial_comparison_frac = initial_comparison_frac
        self.initial_epoch_multiplier = initial_epoch_multiplier
        self.num_iterations = num_iterations
        self.transition_oversampling = transition_oversampling
        # self.draw_freq = draw_freq  # draw_freq = 1 when asking for human feedback
        self.use_wandb = use_wandb
        self.logger = logger
        self.human_fb = human_fb

        # Init schedule
        if callable(query_schedule):
            self.query_schedule = query_schedule
        elif query_schedule in QUERY_SCHEDULES:
            self.query_schedule = QUERY_SCHEDULES[query_schedule]
        else:
            raise ValueError(f"Unknown query schedule: {query_schedule}")

        # Init preference dataset
        self.dataset = PreferenceDatasetNoDiscard(max_size=comparison_queue_size, device=device)
        self.dataset_path = dataset_path

        # Init gym
        self.gym = gym

    def train(
        self,
        total_timesteps: int,
        total_comparisons: int,
    ):
        """Train the reward model and the policy if applicable.

        Args:
            total_timesteps: number of environment interaction steps.
            total_comparisons: number of preferences to gather in total.

        Returns:
            A dictionary with final metrics such as loss and accuracy
            of the reward model.
        """
        initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
        total_comparisons -= initial_comparisons

        # Compute the number of comparisons to request at each iteration in advance (with schedule).
        vec_schedule = np.vectorize(self.query_schedule)
        unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
        probs = unnormalized_probs / np.sum(unnormalized_probs)
        shares = util.oric(probs * total_comparisons)
        schedule = [initial_comparisons] + shares.tolist()
        self.logger.info(f"Query schedule: {schedule}")

        timesteps_per_iteration, extra_timesteps = divmod(
            total_timesteps,
            self.num_iterations,
        )

        # MAIN LOOP
        for i, num_pairs in enumerate(schedule):
            self.logger.info(f"\n \n ROUND {i}")

            #############################################
            # Generate trajectories with trained policy #
            #############################################
            # Generate trajectories
            nb_traj = self.transition_oversampling * 2 * num_pairs
            self.logger.info(f"Collecting {nb_traj} trajectories")
            trajectories, success_rate, _ = self.gym.generate_trajectories(nb_traj, draw=self.human_fb)
            self.logger.debug(f"Nb of trajectories generated: {len(trajectories)}")
            self.logger.info(f"Success rate: {success_rate}")

            # Create pairs of trajectories (to be compared)
            self.logger.info("Creating trajectory pairs")
            pairs = self.pair_generator(trajectories, num_pairs, self.transition_oversampling)  # oversample when there is disagreement
            self.logger.debug("Pair formation done")

            ##########################
            # Gather new preferences #
            ##########################
            # Gather synthetic or human preferences
            self.logger.info("Gathering preferences")
            pairs, preferences = self.preference_gatherer(pairs)
            self.logger.debug(f"Pairs: {len(pairs)}")
            self.logger.debug("Gathering over")
            self.logger.debug(f"Preferences gathered: {preferences}")

            # Store preferences in the preference dataset
            self.dataset.push(pairs, preferences)
            self.logger.info(f"Dataset now contains {len(self.dataset)} comparisons")

            ##########################
            # Train the reward model #
            ##########################
            # On the first iteration, we train the reward model for longer,
            # as specified by initial_epoch_multiplier.
            epoch_multip = 1.0
            if i == 0:
                epoch_multip = self.initial_epoch_multiplier  # default: 20.0
            self.logger.info("\n Training reward model")
            self.reward_trainer.train(self.dataset, epoch_multiplier=epoch_multip)
            self.logger.debug("Reward training finished")

            ###################
            # Train the agent #
            ###################
            num_steps = timesteps_per_iteration
            # If the number of timesteps per iteration doesn't exactly divide
            # the desired total number of timesteps, we train the agent a bit
            # longer at the end of training (where the reward model is
            # presumably best).
            if i == self.num_iterations - 1:
                num_steps += extra_timesteps
            self.logger.info("\n Training agent")
            self.gym.training(nb_episodes=1000, rlhf=self.use_wandb)
            self.logger.debug("Training finished")

            # If human feedback was gathered, save the preference dataset
            if self.dataset_path:
                self.logger.info("\n Preference dataset saved")
                self.dataset.save(self.dataset_path)
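

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the original
# module). It assumes a gym-like wrapper exposing `generate_trajectories` and
# `training`, plus concrete implementations of the rlhf_* components; every
# constructor argument written as `...` and the name `my_gym` are placeholders,
# not real APIs.
#
# if __name__ == "__main__":
#     import logging
#
#     logging.basicConfig(level=logging.INFO)
#     logger = logging.getLogger("rlhf")
#
#     reward_model = RewardModel(...)                 # hypothetical arguments
#     pair_generator = PairGenerator(...)             # hypothetical arguments
#     preference_gatherer = PreferenceGatherer(...)   # hypothetical arguments
#     reward_trainer = RewardTrainer(...)             # hypothetical arguments
#
#     comparisons = PreferenceComparisons(
#         gym=my_gym,                                 # hypothetical environment wrapper
#         reward_model=reward_model,
#         num_iterations=10,
#         pair_generator=pair_generator,
#         preference_gatherer=preference_gatherer,
#         reward_trainer=reward_trainer,
#         query_schedule="hyperbolic",
#         logger=logger,
#     )
#     comparisons.train(total_timesteps=100_000, total_comparisons=500)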