-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathgrpo.py
223 lines (204 loc) · 8.16 KB
/
grpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import dataclasses
import gc
import math
from collections import defaultdict
from typing import Callable, List
import numpy as np
import torch
from data_types import Episode, MiniBatch
from qwen2_model import Transformer
from tokenizer import Tokenizer
@torch.no_grad()
def rollout(
    model: Transformer,
    batch: MiniBatch,
    tokenizer: Tokenizer,
    max_gen_len: int,
    num_answer_per_question: int,
    reward_function: Callable,
    device: torch.device,
    dtype: torch.dtype,
) -> List[Episode]:
    """Sample ``num_answer_per_question`` completions per prompt and score them.

    Every prompt in ``batch`` is replicated ``num_answer_per_question`` times.
    The copies are laid out contiguously: row ``k * num_answer_per_question + j``
    holds the ``j``-th sample of prompt ``k``.

    Tokens are drawn one position at a time with ``torch.multinomial`` from the
    plain softmax of the last-position logits (no temperature or top-k).
    Sampling stops once every row has sampled ``eos_token_id``, or once
    ``max_gen_len`` positions past the longest prompt have been filled.

    Args:
        model: policy model; must provide ``init_kv_cache``, ``inference`` and
            ``del_kv_cache``.
        batch: must provide ``prefix``, ``prefix_token_ids``, ``prefix_tokens``,
            ``numbers`` and ``target`` (indexed per prompt).
        tokenizer: must provide ``eos_token``, ``eos_token_id``,
            ``pad_token_id`` and ``detokenize``.
        max_gen_len: positions allotted for generation beyond the longest
            prompt.
        num_answer_per_question: samples drawn per prompt (the group size).
        reward_function: called as ``reward_function(response=..., numbers=...,
            target=..., end_token=...)``; must return a mapping with
            ``"reward"`` and ``"reward_info"`` keys.
        device: device for the token buffer and the KV cache.
        dtype: dtype for autocast and the KV cache.

    Returns:
        One ``Episode`` per sample, in prompt-major order.
        ``generated_token_ids`` is the completion cut at its first
        ``pad_token_id``. It therefore keeps a sampled end token only when that
        token differs from the pad token.
    """
    end_token = tokenizer.eos_token
    end_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id
    prefix_token_ids = batch.prefix_token_ids
    bsz = len(batch.prefix) * num_answer_per_question
    min_prompt_len = min(len(t) for t in prefix_token_ids)
    max_prompt_len = max(len(t) for t in prefix_token_ids)
    total_len = max_gen_len + max_prompt_len
    # Generation starts at the shortest prompt's end, so at least one position
    # must be left to generate. Checked before any KV-cache allocation.
    assert min_prompt_len < total_len

    model.init_kv_cache(
        max_batch_size=bsz,
        max_seq_len=total_len,
        device=device,
        dtype=dtype,
    )
    # Release the KV cache even if generation raises, so a failed rollout does
    # not leave it allocated.
    try:
        # Right-pad every prompt copy into a single (bsz, total_len) buffer.
        tokens = torch.full(
            (bsz, total_len), pad_token_id, dtype=torch.long, device=device
        )
        for k, t in enumerate(prefix_token_ids):
            offset = k * num_answer_per_question
            for i in range(num_answer_per_question):
                tokens[offset + i, : len(t)] = torch.tensor(
                    t, dtype=torch.long, device=device
                )

        prev_pos = 0
        # True where a prompt token was written. Prompts have different lengths,
        # so at those positions the prompt token is copied through rather than
        # sampled.
        input_text_mask = tokens != pad_token_id
        is_finished = torch.zeros((bsz,), dtype=torch.bool, device=device)
        for cur_pos in range(min_prompt_len, total_len):
            print(
                f"\r* Generating trajectories: {cur_pos-min_prompt_len:>4d}/{total_len-min_prompt_len:>4d}",
                flush=True,
                end="",
            )
            # Only the new slice [prev_pos, cur_pos) is fed; earlier positions
            # are presumably served from the model's KV cache.
            with torch.autocast(device_type=device.type, dtype=dtype):
                logits = model.inference(tokens[:, prev_pos:cur_pos], prev_pos)
            probs = torch.softmax(logits[:, -1], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token = next_token.reshape(-1)
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            # if a rollout is finished, we fill the rest of the tokens with pad_token_id
            next_token = torch.where(is_finished, pad_token_id, next_token)
            tokens[:, cur_pos] = next_token
            if end_token_id is not None:
                # Only a *sampled* end token finishes a row; one that is part of
                # the prompt does not.
                is_end_token = next_token == end_token_id
                is_generated_token = ~input_text_mask[:, cur_pos]
                is_finished = is_finished | (is_end_token & is_generated_token)
            prev_pos = cur_pos
            if is_finished.all():
                break
    finally:
        model.del_kv_cache()
        gc.collect()
        torch.cuda.empty_cache()

    is_finished_list = is_finished.tolist()
    tokens_list = tokens.tolist()

    # Build the output episodes; row i * num_answer_per_question + j is the
    # j-th sample of prompt i.
    episodes = []
    for i in range(bsz // num_answer_per_question):
        prompt_len = len(batch.prefix_token_ids[i])
        for j in range(num_answer_per_question):
            idx = i * num_answer_per_question + j
            generated_token_ids = tokens_list[idx][prompt_len:]
            # remove padding: drop everything from the first pad token onwards
            if pad_token_id in generated_token_ids:
                generated_token_ids = generated_token_ids[
                    : generated_token_ids.index(pad_token_id)
                ]
            generated_text = tokenizer.detokenize(generated_token_ids)
            rewards = reward_function(
                response=generated_text,
                numbers=batch.numbers[i],
                target=batch.target[i],
                end_token=end_token,
            )
            episodes.append(
                Episode(
                    prefix=batch.prefix[i],
                    text=batch.prefix[i] + generated_text,
                    prefix_token_ids=batch.prefix_token_ids[i],
                    prefix_tokens=batch.prefix_tokens[i],
                    generated_token_ids=generated_token_ids,
                    is_finished=is_finished_list[idx],
                    reward=rewards["reward"],
                    reward_info=rewards["reward_info"],
                )
            )
    # Clear the progress line: blank it out, then return the cursor to
    # column 0 so subsequent output starts at the beginning of the line.
    print("\r" + " " * 100 + "\r", end="", flush=True)
    return episodes
def normalize_rewards_per_group(episodes: List[Episode]) -> List[Episode]:
    """Standardize each episode's reward within its group.

    Episodes with the same ``prefix`` form one group. Within a group, each
    reward becomes ``(reward - mean) / (std + 1e-4)``, where ``std`` is the
    population standard deviation (``np.std`` default). A group whose rewards
    are all equal therefore maps to ~0.

    New episode objects are produced with ``dataclasses.replace``, so the
    inputs are not mutated. The output is ordered group by group, with groups
    in order of first appearance and episodes within a group in input order.
    """
    buckets = defaultdict(list)
    for ep in episodes:
        buckets[tuple(ep.prefix)].append(ep)

    normalized = []
    for members in buckets.values():
        rewards = [member.reward for member in members]
        mu = np.mean(rewards)
        sigma = np.std(rewards)
        normalized.extend(
            dataclasses.replace(member, reward=(member.reward - mu) / (sigma + 1e-4))
            for member in members
        )
    return normalized
def compute_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Return the entropy (in nats) of ``softmax(logits)`` along the last axis.

    Uses ``H = logsumexp(z) - sum(softmax(z) * z)``, which is ``-sum(p * log p)``
    rewritten so that ``log p`` is never formed explicitly. The result has the
    input's shape with the last dimension removed.
    """
    log_partition = torch.logsumexp(logits, dim=-1)
    expected_logit = (torch.nn.functional.softmax(logits, dim=-1) * logits).sum(dim=-1)
    return log_partition - expected_logit
def _collate_micro_batch(
    batch_episodes: List[Episode],
    pad_token_id: int,
    device: torch.device,
):
    """Right-pad one micro-batch into ``(token_ids, loss_mask, advantages)``.

    Returns:
        token_ids: ``long`` tensor of shape (B, L). Each row is prefix ids
            followed by generated ids, padded with ``pad_token_id`` to the
            longest sequence L in the micro-batch.
        loss_mask: ``bool`` tensor of shape (B, L), True exactly at the
            generated-token positions.
        advantages: ``float32`` tensor of shape (B,) holding each episode's
            ``reward``.
    """
    lengths = [
        len(ep.prefix_token_ids) + len(ep.generated_token_ids)
        for ep in batch_episodes
    ]
    max_length = max(lengths)
    token_ids = [
        ep.prefix_token_ids
        + ep.generated_token_ids
        + [pad_token_id] * (max_length - length)
        for ep, length in zip(batch_episodes, lengths)
    ]
    loss_mask = [
        [0] * len(ep.prefix_token_ids)
        + [1] * len(ep.generated_token_ids)
        + [0] * (max_length - length)
        for ep, length in zip(batch_episodes, lengths)
    ]
    advantages = [ep.reward for ep in batch_episodes]
    return (
        torch.tensor(token_ids, device=device, dtype=torch.long),
        torch.tensor(loss_mask, device=device, dtype=torch.bool),
        torch.tensor(advantages, device=device, dtype=torch.float32),
    )


def update_policy(
    model,
    optimizer,
    episodes: List[Episode],
    micro_batch_size: int,
    pad_token_id: int,
    max_grad_norm: float,
    device: torch.device,
    dtype: torch.dtype,
):
    """Update the policy using the GRPO algorithm.

    Rewards are first standardized per group (``normalize_rewards_per_group``).
    Each episode's normalized reward is then used as the advantage for every
    one of its generated tokens. The objective ``advantage * log p(token)`` is
    summed over generated tokens and divided by the number of generated tokens
    in the *whole* batch.

    Gradients are accumulated over micro-batches of ``micro_batch_size``
    episodes. They are clipped to ``max_grad_norm`` and a single optimizer step
    is taken.

    Returns:
        dict of plain floats:
        - ``"loss"``: policy loss summed over all micro-batches, i.e. the loss
          of the whole batch.
        - ``"grad_norm"``: the total gradient norm as returned by
          ``clip_grad_norm_``.
        - ``"entropy"``: per-token entropy averaged over generated tokens.

        If ``episodes`` is empty, all three are 0.0 and neither the model nor
        the optimizer is touched.
    """
    episodes = normalize_rewards_per_group(episodes)
    if not episodes:
        # Nothing to learn from; avoid stepping the optimizer on no gradients.
        return {"loss": 0.0, "grad_norm": 0.0, "entropy": 0.0}
    # sort episodes by token length for efficient (micro-)batching
    episodes.sort(key=lambda x: len(x.prefix_token_ids) + len(x.generated_token_ids))
    # Every micro-batch divides by the token count of the whole batch, so the
    # accumulated gradient equals that of one per-token-mean loss. Clamp to 1
    # so that a batch of empty completions gives a zero loss instead of
    # 0/0 = NaN gradients that would corrupt the weights on optimizer.step().
    num_target_tokens = max(
        sum(len(episode.generated_token_ids) for episode in episodes), 1
    )
    total_loss = 0.0
    entropy = 0.0
    for start in range(0, len(episodes), micro_batch_size):
        print(
            f"\r* Computing policy gradient: {start:>2d}/{len(episodes):>2d}",
            flush=True,
            end="",
        )
        batch_episodes = episodes[start : start + micro_batch_size]
        batch_token_ids, batch_masks, batch_advantages = _collate_micro_batch(
            batch_episodes, pad_token_id, device
        )
        with torch.autocast(device_type=device.type, dtype=dtype):
            input_token_ids = batch_token_ids[:, :-1]
            target_token_ids = batch_token_ids[:, 1:]
            # Position t predicts token t + 1, so the mask shifts with targets.
            target_masks = batch_masks[:, 1:]
            logits = model.forward(input_token_ids).float()
        log_probs = -torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            target_token_ids.reshape(-1),
            ignore_index=pad_token_id,
            reduction="none",
        ).reshape(input_token_ids.shape[0], -1)
        with torch.no_grad():
            token_entropy = compute_entropy(logits)
            entropy = entropy + (token_entropy * target_masks).sum() / num_target_tokens
        obj = log_probs * batch_advantages[:, None]
        # per-token objective
        obj = (obj * target_masks).sum() / num_target_tokens
        loss = -obj
        loss.backward()
        # Each micro-batch loss is a partial sum of the whole-batch loss.
        total_loss = total_loss + loss.detach()
    # update the policy
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=max_grad_norm
    )
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return {
        "loss": float(total_loss),
        "grad_norm": grad_norm.item(),
        "entropy": float(entropy),
    }