Merged

tests/test_utils.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -873,7 +873,7 @@ def test_with_scalar(self):
 class SplitPixelValuesByGridTester(TrlTestCase):
     def test_split_correctly_0(self):
         batch = {
-            "image_split_sizes": [4, 4],
+            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
             "pixel_values": torch.arange(8 * 3).reshape(8, 3),  # Shape: [8, 3]
         }
         result = split_pixel_values_by_grid(batch)
@@ -884,7 +884,7 @@ def test_split_correctly_0(self):
 
     def test_split_correctly_1(self):
         batch = {
-            "image_split_sizes": [4, 8],
+            "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
             "pixel_values": torch.arange(12 * 3).reshape(12, 3),  # Shape: [12, 3]
         }
         result = split_pixel_values_by_grid(batch)
@@ -900,7 +900,7 @@ def test_missing_keys(self):
 
     def test_mismatched_length(self):
         batch = {
-            "image_split_sizes": torch.tensor([2, 2]),  # Total = 4
+            "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]),  # Total = 4
            "pixel_values": torch.randn(3, 5),  # Only 3 rows
         }
         with self.assertRaises(ValueError):
```
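
The updated fixtures encode the same split sizes as before, now derived from the grid: each `image_grid_thw` row is `(t, h, w)`, and its product gives the number of `pixel_values` rows belonging to that image. A quick standalone check (plain PyTorch, nothing TRL-specific):

```python
import torch

# Each image_grid_thw row is (t, h, w); its product is the number of
# pixel_values rows contributed by that image.
grid = torch.tensor([[1, 2, 2], [1, 2, 4]])
print(grid.prod(-1).tolist())  # [4, 8] -> the old "image_split_sizes": [4, 8]
```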

trl/experimental/gfpo/gfpo_trainer.py (17 changes: 1 addition & 16 deletions)

```diff
@@ -93,19 +93,13 @@ def _generate_and_score_completions(self, inputs):
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)
 
-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
 
         prompt_inputs = self.processing_class(
@@ -116,13 +110,9 @@ def _generate_and_score_completions(self, inputs):
             add_special_tokens=False,
             **kwargs,
         )
-        prompt_inputs = super(_GRPOTrainer, self)._prepare_inputs(prompt_inputs)
+        prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
```

Review comment (Member), on the `super()._prepare_inputs(prompt_inputs)` line: Much cleaner!
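
For context, a generic Python sketch of why the shorter spelling can behave the same way (hypothetical classes `A`, `B`, `C`, not TRL's actual hierarchy): zero-argument `super()` inside a method of class `C` is shorthand for `super(C, self)`, and the two spellings resolve to the same implementation whenever the explicitly skipped class does not override the method itself.

```python
class A:
    def prep(self):
        return "A.prep"

class B(A):
    pass  # does not override prep

class C(B):
    def run(self):
        # Zero-argument super() == super(C, self): the MRO search starts
        # just after C, i.e. at B, and falls through to A.prep.
        # super(B, self) starts just after B and also lands on A.prep.
        return super().prep(), super(B, self).prep()

print(C().run())  # ('A.prep', 'A.prep')
```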
```diff
@@ -407,7 +397,6 @@ def _generate_and_score_completions(self, inputs):
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 old_per_token_logps = None
@@ -432,7 +421,6 @@ def _generate_and_score_completions(self, inputs):
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -446,7 +434,6 @@ def _generate_and_score_completions(self, inputs):
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
         else:
             ref_per_token_logps = None
@@ -652,6 +639,4 @@ def _generate_and_score_completions(self, inputs):
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output
```

trl/trainer/grpo_trainer.py (31 changes: 6 additions & 25 deletions)

```diff
@@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -804,15 +803,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
 
-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
```
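
The new slicing keeps `pixel_values` aligned with the sub-batch of images without carrying a separate `image_split_sizes` list: the rows belonging to images `[start, start + batch_size)` span the cumulative grid products on either side of that window. A standalone sketch with made-up shapes (plain PyTorch, not TRL code):

```python
import torch

# Three images whose (t, h, w) grids contribute 4, 8 and 4 pixel rows.
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 4], [1, 2, 2]])
pixel_values = torch.arange(16 * 3).reshape(16, 3)  # 4 + 8 + 4 = 16 rows

start, batch_size = 1, 1  # sub-batch containing only the middle image
start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()              # 4
end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()  # 12
print(pixel_values[start_pixel_idx:end_pixel_idx].shape)  # torch.Size([8, 3])
```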
```diff
@@ -1078,19 +1075,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)
 
-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
 
         prompt_inputs = self.processing_class(
```

Review comment (Member Author), on the removed `_get_num_multimodal_tokens` block: it also avoids calling a private method
```diff
@@ -1104,10 +1095,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -1392,7 +1379,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 old_per_token_logps = None
@@ -1417,7 +1403,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1431,7 +1416,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
         else:
             ref_per_token_logps = None
@@ -1580,8 +1564,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output
 
     def compute_liger_loss(self, unwrapped_model, inputs):
@@ -1656,7 +1638,6 @@ def _compute_loss(self, model, inputs):
             image_grid_thw=inputs.get("image_grid_thw"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
         )
 
         if self.top_entropy_quantile < 1.0:
```

trl/trainer/rloo_trainer.py (31 changes: 6 additions & 25 deletions)

```diff
@@ -777,7 +777,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -790,15 +789,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
 
-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
@@ -1064,19 +1061,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)
 
-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
 
         prompt_inputs = self.processing_class(
@@ -1090,10 +1081,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
@@ -1346,7 +1333,6 @@ def _generate_and_score_completions(
                 image_grid_thw=prompt_inputs.get("image_grid_thw"),
                 pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                 image_sizes=prompt_inputs.get("image_sizes"),
-                image_split_sizes=image_split_sizes,
             )
         old_logps = (old_per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS
 
@@ -1363,7 +1349,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1377,7 +1362,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
         else:
             ref_per_token_logps = None
@@ -1498,8 +1482,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output
 
     @profiling_decorator
@@ -1527,7 +1509,6 @@ def _compute_loss(self, model, inputs):
             image_grid_thw=inputs.get("image_grid_thw"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
         )
 
         logps = (per_token_logps * completion_mask).sum(1)  # mask out padding and tokens after EOS
```

trl/trainer/utils.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio
     Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in
     `batch["image_grid_thw"]`, while keeping other entries unchanged.
     """
-    if "image_split_sizes" not in batch or "pixel_values" not in batch:
+    if "image_grid_thw" not in batch or "pixel_values" not in batch:
         return batch
 
-    lengths = batch["image_split_sizes"]  # [batch_size]
+    lengths = batch["image_grid_thw"].prod(-1).tolist()  # [batch_size]
     pixel_values = batch["pixel_values"]  # [total, feature_dim]
 
     if sum(lengths) != pixel_values.size(0):
```

Review comment (Member), on the `lengths` line: Do I understand correctly that `image_split_sizes` was only really used in this helper method to extract the image lengths? If so, then I agree it's redundant with `image_grid_thw`.

Reply (Member Author): yes, precisely
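
Putting the pieces together, here is a minimal sketch of what `split_pixel_values_by_grid` now does (simplified from `trl/trainer/utils.py`; the real function may handle more keys and edge cases):

```python
import torch

def split_pixel_values_by_grid(batch: dict) -> dict:
    # Simplified sketch: split the flat pixel_values tensor into one chunk
    # per image, sized by the product of that image's (t, h, w) grid.
    if "image_grid_thw" not in batch or "pixel_values" not in batch:
        return batch
    lengths = batch["image_grid_thw"].prod(-1).tolist()  # rows per image
    if sum(lengths) != batch["pixel_values"].size(0):
        raise ValueError("image_grid_thw and pixel_values disagree on total rows")
    return {**batch, "pixel_values": list(torch.split(batch["pixel_values"], lengths))}

batch = {
    "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]),
    "pixel_values": torch.arange(12 * 3).reshape(12, 3),
}
print([t.shape for t in split_pixel_values_by_grid(batch)["pixel_values"]])
# [torch.Size([4, 3]), torch.Size([8, 3])]
```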