Skip to content

Commit 43b6541

Browse files
SolarWindRider, albertvillanova, qgallouedec
authored
Support completion bootstrap for VLM in GRPO/RLOO (#4452)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 642b721 commit 43b6541

File tree

4 files changed

+12
-0
lines changed

4 files changed

+12
-0
lines changed

trl/experimental/gfpo/gfpo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,9 @@ def _generate_and_score_completions(self, inputs):
206206
completions = []
207207
for prompt, completion in zip(prompts, completions_text, strict=True):
208208
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
209+
if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}]
210+
assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text"
211+
bootstrap = bootstrap[0]["text"]
209212
completions.append([{"role": "assistant", "content": bootstrap + completion}])
210213
else:
211214
completions = completions_text

trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,9 @@ def _generate_and_score_completions(
210210
completions = []
211211
for prompt, completion in zip(prompts, completions_text, strict=True):
212212
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
213+
if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}]
214+
assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text"
215+
bootstrap = bootstrap[0]["text"]
213216
completions.append([{"role": "assistant", "content": bootstrap + completion}])
214217
else:
215218
completions = completions_text

trl/trainer/grpo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1575,6 +1575,9 @@ def _generate_and_score_completions(
15751575
completions = []
15761576
for prompt, completion in zip(prompts, completions_text, strict=True):
15771577
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
1578+
if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}]
1579+
assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text"
1580+
bootstrap = bootstrap[0]["text"]
15781581
completions.append([{"role": "assistant", "content": bootstrap + completion}])
15791582
else:
15801583
completions = completions_text

trl/trainer/rloo_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1342,6 +1342,9 @@ def _generate_and_score_completions(
13421342
completions = []
13431343
for prompt, completion in zip(prompts, completions_text, strict=True):
13441344
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
1345+
if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}]
1346+
assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text"
1347+
bootstrap = bootstrap[0]["text"]
13451348
completions.append([{"role": "assistant", "content": bootstrap + completion}])
13461349
else:
13471350
completions = completions_text

0 commit comments

Comments
 (0)