Merged
Changes from all commits
112 commits
552e899
Refactor image handling: replace `image_split_sizes` with `image_grid…
qgallouedec Sep 19, 2025
449ef07
simpler
qgallouedec Sep 19, 2025
c8933aa
gfpo
qgallouedec Sep 19, 2025
229c554
multi-image grpo
qgallouedec Sep 19, 2025
3ca6ad5
log with wandb
qgallouedec Sep 19, 2025
dcf4b92
no vlm reward models
qgallouedec Sep 20, 2025
30ad7ca
rloo
qgallouedec Sep 20, 2025
86cc30b
gfpo
qgallouedec Sep 20, 2025
088897b
fix
qgallouedec Sep 20, 2025
d2adc63
test peft
qgallouedec Sep 20, 2025
f4c82bf
fix gfpo
qgallouedec Sep 20, 2025
1257796
rloo test
qgallouedec Sep 20, 2025
099a39b
peft rloo
qgallouedec Sep 20, 2025
529add6
oops
qgallouedec Sep 20, 2025
fc6b11f
update test
qgallouedec Sep 20, 2025
ae1f497
generate method
qgallouedec Sep 20, 2025
f998432
debug
qgallouedec Sep 20, 2025
fa73876
skip failing test
qgallouedec Sep 20, 2025
52d8bd9
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 20, 2025
dfc0d38
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 20, 2025
fc52e68
test fixed!
qgallouedec Sep 20, 2025
4d12aeb
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 20, 2025
4fc2b5b
gfpo
qgallouedec Sep 20, 2025
b628744
rm vllm
qgallouedec Sep 20, 2025
d3a769f
fix doc
qgallouedec Sep 20, 2025
e17ec42
Merge branch 'main' into drop-image_split_sizes
qgallouedec Sep 22, 2025
efbb03a
Merge branch 'drop-image_split_sizes' into multi-image-support
qgallouedec Sep 22, 2025
562c662
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
485781c
Merge branch 'main' into multi-image-support
qgallouedec Sep 22, 2025
05270f8
update layers to ignore
qgallouedec Sep 22, 2025
1c53094
clarify image column desc
qgallouedec Sep 22, 2025
9b6652e
rm VLM x RM warning
qgallouedec Sep 23, 2025
c500440
Merge branch 'multi-image-support' into generate-method
qgallouedec Sep 23, 2025
a6a8c44
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
d8665e1
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
365d501
Merge branch 'main' into generate-method
qgallouedec Sep 23, 2025
cdb4c76
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
c83e710
same for rloo
qgallouedec Sep 24, 2025
ec6ad25
nits style and align
qgallouedec Sep 24, 2025
b4cadde
Merge branch 'main' into generate-method
qgallouedec Sep 24, 2025
b0dceb9
restart
qgallouedec Sep 25, 2025
ebe32c2
progress
qgallouedec Sep 25, 2025
0213662
progress continues
qgallouedec Sep 25, 2025
8b3a724
progress again again
qgallouedec Sep 25, 2025
c1ae6aa
back to working point
qgallouedec Sep 25, 2025
1a66b43
revert chage data utils
qgallouedec Sep 25, 2025
2dc69a6
Merge branch 'main' into generate-method
qgallouedec Sep 26, 2025
9435a94
refactor in grpo
qgallouedec Sep 26, 2025
d3f1d3c
Merge branch 'main' into refactor_generate
qgallouedec Sep 26, 2025
3d8ea27
wrong merge commit
qgallouedec Sep 26, 2025
27dc958
fix num_input_tokens_seen
qgallouedec Sep 26, 2025
53772ef
getting closer
qgallouedec Sep 26, 2025
8766fa5
consistent naming
qgallouedec Sep 26, 2025
236b78b
better
qgallouedec Sep 26, 2025
9da4830
simplify a bit + comment
qgallouedec Sep 26, 2025
b3bd0b0
another one
qgallouedec Sep 26, 2025
d79b9e1
get prompt ids from generation
qgallouedec Sep 26, 2025
8d34d54
remove pad token removal
qgallouedec Sep 26, 2025
e770efe
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 26, 2025
0e2ae34
rely on generator for prompt truncation
qgallouedec Sep 26, 2025
46d8eb7
revert
qgallouedec Sep 26, 2025
11acc75
rm enforce eager
qgallouedec Sep 26, 2025
acee7d8
rm truncate_with_protected_tokens
qgallouedec Sep 26, 2025
0b5865e
ensure proper truncation and side
qgallouedec Sep 26, 2025
d8af003
rm useless comment
qgallouedec Sep 26, 2025
fc263a3
rm imports
qgallouedec Sep 26, 2025
35f99fd
requires padding
qgallouedec Sep 26, 2025
8149d05
rm truncation test
qgallouedec Sep 26, 2025
55a2480
rloo + doc
qgallouedec Sep 26, 2025
c8041e1
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 26, 2025
b8c0c9b
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Sep 26, 2025
7b7a11d
test and doc
qgallouedec Sep 27, 2025
c5064d6
gfpo
qgallouedec Sep 27, 2025
effb41b
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
e82bfb4
Merge branch 'main' into refactor_generate
qgallouedec Sep 27, 2025
4b9c126
Merge branch 'refactor_generate' into refactor_generate_2
qgallouedec Sep 27, 2025
3f02702
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Sep 27, 2025
f11759e
Merge branch 'main' into refactor_generate_2
qgallouedec Sep 30, 2025
e7aa945
fix vllm client server
qgallouedec Sep 30, 2025
e164ec5
repicate all_prompt_ids
qgallouedec Oct 1, 2025
49577ad
Same for RLOO
qgallouedec Oct 1, 2025
5fca5b8
fix normal generation path
qgallouedec Oct 1, 2025
5cc6af5
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Oct 1, 2025
4dce145
remove vision tokens
qgallouedec Oct 1, 2025
ddfd3b5
same for rloo
qgallouedec Oct 1, 2025
c434fa2
truncation_side=left
qgallouedec Oct 1, 2025
377b081
rm test_training_vlm_and_prompt_truncation
qgallouedec Oct 1, 2025
d599c20
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 1, 2025
e82db74
🔣 Fix test: replace `trainer.tokenizer` by `trainer.processing_class`…
qgallouedec Oct 1, 2025
192deb3
Fix CI ImportError: FlashAttention2 and decorator order for all param…
albertvillanova Oct 1, 2025
cf9d8e7
Hotfix wrong formatting of docstrings with blockquote tips (#4187)
albertvillanova Oct 1, 2025
f9c3c3c
🌡️ Have vLLM return processed (temperature scaled) log probs (#4163)
YonatanGideoni Oct 1, 2025
6489479
Replace remaining trainer.tokenizer with trainer.processing_class in …
albertvillanova Oct 3, 2025
21a67fc
[DOCS] Lora without regret (#4181)
burtenshaw Oct 3, 2025
c1e7ad2
[DOCS/FIX] lora without regrets - fix lr (#4207)
burtenshaw Oct 6, 2025
5d34144
Remove custome_container for building the docs (#4198)
albertvillanova Oct 6, 2025
ae2a0e7
Remove tokenizer creation from `sft` example script (#4197)
sergiopaniego Oct 6, 2025
6543f51
Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)
albertvillanova Oct 6, 2025
8319ce0
Replace unittest with pytest (#4188)
albertvillanova Oct 6, 2025
4fdaa4c
Updated vLLM integration guide (#4162)
sergiopaniego Oct 6, 2025
d258e36
Remove `Optional` from `processing_class` in `PPOTrainer` (#4212)
sergiopaniego Oct 6, 2025
7f5b499
Replace setup with pyproject and fix packaging unintended modules (#4…
albertvillanova Oct 6, 2025
df386f9
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
5b9a6ab
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
766bbce
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Oct 6, 2025
4a274d5
Merge branch 'main' into refactor_generate_2
qgallouedec Oct 6, 2025
db552be
Merge branch 'refactor_generate_2' into refactor_generate_3
qgallouedec Oct 6, 2025
a84325c
style
qgallouedec Oct 6, 2025
ee03478
remove test case for prompt truncation
qgallouedec Oct 7, 2025
45290c9
Merge branch 'main' into refactor_generate_3
qgallouedec Oct 7, 2025
78132bf
Merge branch 'main' into refactor_generate_3
qgallouedec Oct 10, 2025
f79aba1
Merge branch 'main' into refactor_generate_3
qgallouedec Oct 10, 2025
41 changes: 0 additions & 41 deletions tests/test_grpo_trainer.py
@@ -1471,47 +1471,6 @@ def reward_func(completions, **kwargs):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@require_vision
def test_training_vlm_and_prompt_truncation(self):
# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
max_prompt_length=18,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
params_to_skip = ("model.visual.",)
for n, param in previous_trainable_params.items():
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@parameterized.expand(
[
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
41 changes: 0 additions & 41 deletions tests/test_rloo_trainer.py
@@ -1212,47 +1212,6 @@ def reward_func(completions, **kwargs):
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."

@require_vision
def test_training_vlm_and_prompt_truncation(self):
# If not handled properly, prompt truncation may truncate image token
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

def reward_func(completions, **kwargs):
"""Reward function that rewards longer completions."""
return [float(len(completion[0]["content"])) for completion in completions]

training_args = RLOOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
max_prompt_length=18,
report_to="none",
)
trainer = RLOOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
reward_funcs=reward_func,
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the params have changed
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
params_to_skip = ("model.visual.",)
for n, param in previous_trainable_params.items():
if n.startswith(params_to_skip):
continue
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@parameterized.expand(
[
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
79 changes: 0 additions & 79 deletions tests/test_utils.py
@@ -42,7 +42,6 @@
shuffle_sequence_dict,
split_pixel_values_by_grid,
split_tensor_dict,
truncate_with_protected_tokens,
unsplit_pixel_values_by_grid,
)

@@ -1009,84 +1008,6 @@ def test_multi_images(self):
assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))


class TestTruncateWithProtectedTokens(TrlTestCase):
def test_basic_example(self):
"""Test the basic example from the problem description."""
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [2, 3]
target_length = 3

new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

expected_ids = [2, 3, 5]
assert new_ids == expected_ids

def test_no_truncation_needed(self):
"""Test when target length equals current length."""
prompt_ids = [1, 2, 3]
protected_tokens = [2]
target_length = 3

new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

assert new_ids == prompt_ids

def test_no_protected_tokens(self):
"""Test truncation with no protected tokens (normal right truncation)."""
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = []
target_length = 3

new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

expected_ids = [3, 4, 5] # Last 3 tokens
assert new_ids == expected_ids

def test_all_tokens_protected(self):
"""Test when all remaining tokens are protected."""
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [3, 4, 5]
target_length = 3

new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

expected_ids = [3, 4, 5]
assert new_ids == expected_ids

def test_too_many_protected_tokens(self):
"""Test error when too many protected tokens for target length."""
prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [1, 2, 3, 4]
target_length = 3

with pytest.raises(ValueError):
truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

def test_single_batch_single_token(self):
"""Test edge case with single batch and single token."""
prompt_ids = [5]
protected_tokens = [5]
target_length = 1

new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

assert new_ids == prompt_ids

def test_order_preservation(self):
"""Test that relative order is preserved."""
prompt_ids = [10, 2, 20, 3, 30, 40]
protected_tokens = [2, 3]
target_length = 4

new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

# Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
# Order should be: 2, 3, 30, 40 (maintaining original relative positions)
expected_ids = [2, 3, 30, 40]

assert new_ids == expected_ids


class TestUnsplitPixelValuesByGrid(TrlTestCase):
def test_unsplit_correctly(self):
pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
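The deleted tests above fully pin down the behavior of the dropped `truncate_with_protected_tokens` helper: protected tokens are always kept, the remaining budget is filled with the last non-protected tokens, and the original relative order is preserved. A minimal sketch consistent with those tests (not the original implementation, which this PR removes):

def truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens):
    """Keep every protected token, fill the remaining budget with the last
    non-protected tokens, and preserve the original relative order."""
    protected = set(protected_tokens)
    num_protected = sum(token in protected for token in prompt_ids)
    if num_protected > target_length:
        raise ValueError("Too many protected tokens for the target length.")
    budget = target_length - num_protected  # slots left for non-protected tokens
    keep = set()
    for i in range(len(prompt_ids) - 1, -1, -1):  # scan from the right
        if prompt_ids[i] not in protected and budget > 0:
            keep.add(i)
            budget -= 1
    return [tok for i, tok in enumerate(prompt_ids) if tok in protected or i in keep]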
6 changes: 6 additions & 0 deletions trl/extras/vllm_client.py
@@ -182,6 +182,7 @@ def generate(
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
truncate_prompt_tokens: Optional[int] = None,
guided_decoding_regex: Optional[str] = None,
generation_kwargs: Optional[dict] = None,
) -> list[list[int]]:
@@ -207,6 +208,10 @@
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each prompt.
truncate_prompt_tokens (`int`, *optional*):
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
disabled.
guided_decoding_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
@@ -246,6 +251,7 @@ def pil_to_base64(image):
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"truncate_prompt_tokens": truncate_prompt_tokens,
"guided_decoding_regex": guided_decoding_regex,
"generation_kwargs": generation_kwargs or {},
},
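On the client side, the new argument is simply forwarded with the rest of the sampling parameters. A usage sketch (assuming a `trl vllm-serve` instance is already running and reachable with the client's default host and port; the prompt is illustrative):

from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # assumes a running `trl vllm-serve` on the default host/port
completion_ids = client.generate(
    prompts=["Summarize the TRL library in one sentence."],
    max_tokens=32,
    truncate_prompt_tokens=512,  # keep only the last 512 prompt tokens (left truncation)
)
print(completion_ids[0])  # token ids of the first completion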
5 changes: 5 additions & 0 deletions trl/scripts/vllm_serve.py
@@ -495,6 +495,7 @@ class GenerateRequest(BaseModel):
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
truncate_prompt_tokens: Optional[int] = None
guided_decoding_regex: Optional[str] = None
generation_kwargs: dict = field(default_factory=dict)

@@ -525,6 +526,9 @@ async def generate(request: GenerateRequest):
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
truncation). If set to `None`, truncation is disabled.
- `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
@@ -575,6 +579,7 @@ async def generate(request: GenerateRequest):
"top_k": request.top_k,
"min_p": request.min_p,
"max_tokens": request.max_tokens,
"truncate_prompt_tokens": request.truncate_prompt_tokens,
"guided_decoding": guided_decoding,
"logprobs": 0,
}
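The same field is accepted by the server's generate endpoint directly; a sketch over plain HTTP (the endpoint path, port, and payload values are assumptions based on the script's defaults):

import requests

payload = {
    "prompts": ["Summarize the TRL library in one sentence."],
    "max_tokens": 32,
    "truncate_prompt_tokens": 512,  # left-truncate the prompt to its last 512 tokens
}
response = requests.post("http://localhost:8000/generate/", json=payload)
print(response.json())  # JSON body with the generated token ids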
88 changes: 22 additions & 66 deletions trl/trainer/grpo_trainer.py
@@ -14,7 +14,6 @@

import inspect
import os
import re
import textwrap
from collections import defaultdict, deque
from contextlib import nullcontext
@@ -71,7 +70,6 @@
shuffle_sequence_dict,
split_pixel_values_by_grid,
split_tensor_dict,
truncate_with_protected_tokens,
unsplit_pixel_values_by_grid,
)

@@ -275,7 +273,7 @@ def __init__(

# Processing class
if processing_class is None:
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")

# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):
@@ -291,10 +289,6 @@
self.pad_token = tokenizer.pad_token
self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id
self.image_token = getattr(processing_class, "image_token", None)
self.image_token_id = getattr(processing_class, "image_token_id", None)
self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)

# Reward functions
if not isinstance(reward_funcs, list):
@@ -1092,58 +1086,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
]

prompt_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
**kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

if self.max_prompt_length is not None:
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
# tokens are needed for generation.
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
protected = [token for token in protected if token is not None]
prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]

prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

# The chat template sometimes inserts a single image token into the prompt text. However, when this text is
# later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
# image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
# collapse them back into a single token string to match the original chat template in case it originally
# applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
# (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
# the vision_start_token_id (e.g. <start_of_image>).
if self.image_token is not None:
escaped_img_token = re.escape(self.image_token)
# Search for the image token in the chat template
if re.search(escaped_img_token, self.processing_class.chat_template):
prompts_text = [
re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
]
else:
# If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
if self.vision_end_token_id is not None:
escaped_eoi_token = re.escape(
self.processing_class.tokenizer.decode([self.vision_end_token_id])
)
prompts_text = [
re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
]
else:
# If vision_end_token_id is None, just remove the image tokens
prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
if images is not None:
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
forward_kwargs = {}

# Generate completions using either vLLM or regular generation
if self.use_vllm:
@@ -1185,6 +1133,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
truncate_prompt_tokens=self.max_prompt_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
@@ -1223,6 +1172,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"truncate_prompt_tokens": self.max_prompt_length,
"guided_decoding": guided_decoding,
"logprobs": 0, # only return the logprob of the generated token
}
@@ -1319,7 +1269,17 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):

else:
# Regular generation path
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
generate_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
max_length=self.max_prompt_length,
truncation=True,
add_special_tokens=False,
**kwargs,
)
generate_inputs = super()._prepare_inputs(generate_inputs)

with (
profiling_context(self, "transformers.generate"),
Expand All @@ -1330,15 +1290,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
prompt_completion_ids = unwrapped_model.generate(
input_ids=prompt_ids,
attention_mask=prompt_mask,
**forward_kwargs,
generation_config=self.generation_config,
disable_compile=True,
**generate_inputs, generation_config=self.generation_config, disable_compile=True
)
# Compute prompt length and extract completion ids
prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]

# Mask everything after the first EOS token
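With the helper gone, prompt truncation in the regular generation path is delegated to the processor: `truncation_side="left"` is set when the processing class is created, and `max_length=self.max_prompt_length` with `truncation=True` is passed at encode time, so only the last `max_prompt_length` prompt tokens are kept. A minimal sketch of that behavior (the tokenizer name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")
encoded = tokenizer(
    ["word " * 100],     # a prompt far longer than the budget
    max_length=16,
    truncation=True,     # with truncation_side="left", keeps the last 16 tokens
    padding=True,
    padding_side="left",
    return_tensors="pt",
)
print(encoded["input_ids"].shape)  # torch.Size([1, 16])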