Reproduction
@qgallouedec Thank you for your hard work on #4113
New issues:
Test Code
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("custom_with multiple images in prompt", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Llava7b-GRPO")
trainer = GRPOTrainer(
    model="llava-hf/llava-1.5-7b-hf",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
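For context, each row of that custom dataset pairs a conversational prompt containing several image placeholders with a list of PIL images. A rough, hypothetical sketch of one row (field names are illustrative, not my private dataset's exact schema):

from PIL import Image

# Hypothetical example of one dataset row with two images in the prompt.
# "prompt" uses the conversational format; "images" holds the PIL images.
example_row = {
    "prompt": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "image"},
                {"type": "text", "text": "What changed between the two images?"},
            ],
        }
    ],
    "images": [
        Image.new("RGB", (336, 336), "white"),
        Image.new("RGB", (336, 336), "black"),
    ],
}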
Issue 1 - Cannot train GRPO with multiple images and Llava-1.5-7b:
- Output:
  ValueError: Image features and image tokens do not match: tokens: 3456, features 2359296
- Reason: The missing `image_grid_thw` key in the batch input skips the split processing in `split_pixel_values_by_grid` and makes the image size == 1 during `_get_per_token_logps_and_entropies`.
- My implementation:
In `trl.trainer.utils`:
def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
    """
    Splits `batch["pixel_values"]` into a list of tensors, one per prompt, based on the product of each row in
    `batch["image_grid_thw"]` (falling back to one row per image when `image_grid_thw` is absent) and
    `batch["num_images"]`, while keeping other entries unchanged.
    """
    if "pixel_values" not in batch or "num_images" not in batch:
        return batch  # type: ignore
    if "image_grid_thw" not in batch:
        # Models like Llava do not provide image_grid_thw: each image occupies exactly one row of pixel_values
        lengths = sum(batch["num_images"]) * [1]
    else:
        lengths = batch["image_grid_thw"].prod(-1).tolist()  # [num_images]
    pixel_values = batch["pixel_values"]  # [total, feature_dim]
    if sum(lengths) != pixel_values.size(0):
        raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}")
    boundaries = [0, *accumulate(batch["num_images"])]  # e.g. [3, 4, 5] -> [0, 3, 7, 12]
    sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
    split_values = list(torch.split(batch["pixel_values"], sections, dim=0))  # type: ignore
    if "image_grid_thw" in batch:
        image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))  # type: ignore
        return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}
    return {**batch, "pixel_values": split_values}
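As a quick sanity check, the patched function splits a Llava-style batch (no `image_grid_thw`, one pixel-value row per image; the shapes below are made up) into one tensor per prompt:

import torch

# Hypothetical Llava-style batch: 3 images total across 2 prompts (2 + 1),
# pixel_values holding one row per image, and no image_grid_thw key.
batch = {
    "pixel_values": torch.randn(3, 3, 336, 336),
    "num_images": [2, 1],
}
out = split_pixel_values_by_grid(batch)
print([v.shape for v in out["pixel_values"]])
# [torch.Size([2, 3, 336, 336]), torch.Size([1, 3, 336, 336])]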
In `trl.trainer.grpo_trainer`, at `GRPOTrainer._get_per_token_logps_and_entropies`:

elif pixel_values is not None:
    # Slice the flat image rows belonging to prompts [start, start + batch_size)
    cum_imgs = torch.tensor([0] + num_images).cumsum(0)  # type: ignore
    img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
model_inputs["pixel_values"] = pixel_values[img_start:img_end]Issue 2 - Cannot train with vLLM serve
Issue 2 - Cannot train with vLLM serve
- Output:
output = self.vllm_client.generate(
  File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 240, in generate
    images = [pil_to_base64(img) for img in images] if images else None
  File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 240, in <listcomp>
    images = [pil_to_base64(img) for img in images] if images else None
  File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 235, in pil_to_base64
    image.save(buffer, format="PNG")
AttributeError: 'list' object has no attribute 'save'
- Reason: `pil_to_base64` expects a flat list of images, not a list of lists of images as produced when a prompt contains multiple images.
trl/extras/vllm_client.py, lines 233 to 240 in 04fd120:
def pil_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    img_bytes = buffer.getvalue()
    return base64.b64encode(img_bytes).decode("utf-8")

# Convert PIL images to base64 strings
images = [pil_to_base64(img) for img in images] if images else None

- Also: it seems to expect multimodal data with a single image per prompt.
Lines 562 to 566 in 04fd120
for prompt, image in zip(request.prompts, request.images):
    row = {"prompt": prompt}
    if image is not None:
        row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
    prompts.append(row)

- Solution: working on it.
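One direction I am experimenting with (an untested sketch, not the actual TRL fix) is to keep the per-prompt nesting when encoding on the client side, and symmetrically pass a list of decoded images under `multi_modal_data["image"]` on the server side. The helper name `encode_images` is hypothetical:

def encode_images(images):
    # Reuses pil_to_base64 from the snippet above. `images` may be a flat list
    # of PIL images, or a list of lists (one sub-list per prompt) when prompts
    # contain several images; preserve that nesting instead of crashing.
    if not images:
        return None
    if isinstance(images[0], list):
        return [[pil_to_base64(img) for img in prompt_images] for prompt_images in images]
    return [pil_to_base64(img) for img in images]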
Issue 3 - Gradient Checkpointing
- Cannot train with:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

With `use_reentrant: false`, training reports "Gradients will be None"; with `use_reentrant: true`, it raises a "has been marked as ready twice" error. For the moment I have disabled gradient checkpointing, but I will need it in the future, so I am flagging it here.
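For reference, the same settings expressed directly in Python (these map to standard TrainingArguments fields that GRPOConfig inherits, so the YAML itself should not be the problem):

from trl import GRPOConfig

# Same gradient checkpointing settings as the YAML above, passed via GRPOConfig.
training_args = GRPOConfig(
    output_dir="Llava7b-GRPO",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)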
System Info
- Platform: Linux-6.14.0-1015-x86_64-with-glibc2.39
- Python version: 3.10.18
- TRL version: 0.24.0
- PyTorch version: 2.8.0
- accelerator(s): NVIDIA L40S, NVIDIA L40S, NVIDIA L40S, NVIDIA L40S
- Transformers version: 4.57.1
- Accelerate version: 1.11.0
- Accelerate config: not found
- Datasets version: 4.3.0
- HF Hub version: 0.36.0
- bitsandbytes version: 0.47.0
- DeepSpeed version: 0.17.4
- Liger-Kernel version: 0.6.1
- LLM-Blender version: not installed
- OpenAI version: 1.100.2
- PEFT version: 0.17.1
- vLLM version: 0.11.0
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
- Any traceback provided is complete