Skip to content

Commit 4266550

Browse files
Slice std_rewards
1 parent d032423 commit 4266550

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def _generate_and_score_completions(
270270
)
271271
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
272272
advantages = advantages[process_slice]
273+
std_rewards = std_rewards[process_slice]
273274

274275
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
275276
for i, reward_func_name in enumerate(self.reward_func_names):

0 commit comments

Comments
 (0)