Skip to content

Commit 50b96e2

Browse files
Fix CI experimental tests TypeError for GRPOWithReplayBufferTrainer.update_with_replay_buffer (#4366)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 3d718df commit 50b96e2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/experimental/test_grpo_with_replay_buffer_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False)
140 140   "prompt_mask": torch.ones(4, 2, dtype=torch.long),
141 141   "completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]),
142 142   "completion_mask": torch.ones(4, 2, dtype=torch.long),
143     - "prompt_inputs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {},
    143 + "forward_kwargs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {},
144 144   "old_per_token_logps": torch.randn(4, 2) if with_logprobs else None,
145 145   }
146 146   inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
@@ -217,7 +217,7 @@ def test_update_with_inputs_different_seq_len(self):
217 217   ]
218 218   ),
219 219   "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long),
220     - "prompt_inputs": {},
    220 + "forward_kwargs": {},
221 221   }
222 222   inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
223 223

0 commit comments

Comments
 (0)