🟩 Drop `image_split_sizes` in favour of `image_grid_thw` · PR #4111

This change removes the `image_split_sizes` plumbing from the trainer: per-image patch counts are now derived directly from `image_grid_thw` via row-wise products, so the extra argument no longer needs to be computed, passed around, or stored.
Changes from all commits: 552e899, 449ef07, c8933aa, 52d8bd9, e17ec42
```diff
@@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
```
```diff
@@ -804,15 +803,13 @@ def _get_per_token_logps_and_entropies
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}

-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
```
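The new indexing works because each row of `image_grid_thw` holds an image's `(t, h, w)` patch grid, so the row product is exactly the number of rows that image occupies in the flat `pixel_values` tensor, and cumulative row products serve as slice offsets. A minimal sketch of that identity with made-up grid values (illustrative only, not trainer code; assumes one image per prompt):

```python
import torch

# Hypothetical patch grids for three images: row products are 16, 4, and 12 patches.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2], [1, 6, 2]])
pixel_values = torch.randn(int(image_grid_thw.prod(-1).sum()), 8)  # [total_patches, feature_dim]

# Slice out the chunk covering images 1..2, exactly as the new trainer code does.
start, batch_size = 1, 2
start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()               # 16
end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()   # 16 + 4 + 12 = 32
chunk = pixel_values[start_pixel_idx:end_pixel_idx]
assert chunk.size(0) == 16  # 4 + 12 patch rows for images 1 and 2
```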
```diff
@@ -1078,19 +1075,13 @@ def _generate_and_score_completions
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)

-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]

         prompt_inputs = self.processing_class(
```

> **Author:** it also avoids calling a private method
```diff
@@ -1104,10 +1095,6 @@ def _generate_and_score_completions
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
```
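This removed fallback is also why the migration is lossless: it computed `image_split_sizes` as `image_grid_thw.prod(dim=1).tolist()`, so the old prefix sums over `image_split_sizes` and the new in-place grid products always agree. A quick check with illustrative grid values (not trainer code):

```python
import torch

# If image_split_sizes was derived the way the fallback did it,
# both indexing schemes yield identical slice offsets for every prefix.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2], [1, 6, 2]])
image_split_sizes = image_grid_thw.prod(dim=1).tolist()  # [16, 4, 12], as in the removed fallback

for k in range(len(image_split_sizes) + 1):
    assert sum(image_split_sizes[:k]) == image_grid_thw[:k].prod(-1).sum().item()
```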
```diff
@@ -1392,7 +1379,6 @@ def _generate_and_score_completions
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 old_per_token_logps = None
```
```diff
@@ -1417,7 +1403,6 @@ def _generate_and_score_completions
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
```
```diff
@@ -1431,7 +1416,6 @@ def _generate_and_score_completions
                             image_grid_thw=prompt_inputs.get("image_grid_thw"),
                             pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                             image_sizes=prompt_inputs.get("image_sizes"),
-                            image_split_sizes=image_split_sizes,
                         )
         else:
             ref_per_token_logps = None
```
```diff
@@ -1580,8 +1564,6 @@ def _generate_and_score_completions
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output

     def compute_liger_loss(self, unwrapped_model, inputs):
```
```diff
@@ -1656,7 +1638,6 @@ def _compute_loss(self, model, inputs)
             image_grid_thw=inputs.get("image_grid_thw"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
         )

         if self.top_entropy_quantile < 1.0:
```
```diff
@@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio
     Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in
     `batch["image_grid_thw"]`, while keeping other entries unchanged.
     """
-    if "image_split_sizes" not in batch or "pixel_values" not in batch:
+    if "image_grid_thw" not in batch or "pixel_values" not in batch:
         return batch

-    lengths = batch["image_split_sizes"]  # [batch_size]
+    lengths = batch["image_grid_thw"].prod(-1).tolist()  # [batch_size]
     pixel_values = batch["pixel_values"]  # [total, feature_dim]

     if sum(lengths) != pixel_values.size(0):
```

> **Member:** Do I understand correctly, that …
>
> **Author:** yes, precisely
> Much cleaner!
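For intuition, here is a minimal sketch of what the updated helper now computes, with illustrative shapes (the real function also keeps the other batch entries unchanged and validates the total length against `pixel_values`):

```python
import torch

# Rebuild the per-image list from the flat pixel_values using grid row products.
batch = {
    "image_grid_thw": torch.tensor([[1, 4, 4], [1, 2, 2]]),
    "pixel_values": torch.randn(20, 8),  # 16 + 4 patch rows, feature_dim 8
}
lengths = batch["image_grid_thw"].prod(-1).tolist()        # [16, 4]
split = list(torch.split(batch["pixel_values"], lengths))  # one tensor per image
assert [t.size(0) for t in split] == lengths
```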