
Issues on GRPO with VLM and vLLM (cont.) #4488

@Fhrozen


Reproduction

@qgallouedec Thank you for your hard work on #4113

New issues:

Test Code

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("custom_with multiple images in prompt ", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Llava7b-GRPO")
trainer = GRPOTrainer(
    model="llava-hf/llava-1.5-7b-hf",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
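A side note on the reward function: if the dataset is conversational (typical for VLM prompts), TRL passes each completion as a list of message dicts rather than a plain string, so `len(completion)` would count messages instead of characters. A minimal sketch that handles both formats, assuming each message's `content` is a plain string:

```python
def reward_len(completions, **kwargs):
    """Reward completions whose text length is close to 20 characters."""
    rewards = []
    for completion in completions:
        if isinstance(completion, list):
            # Conversational format (assumed): [{"role": "assistant", "content": "..."}]
            text = "".join(message["content"] for message in completion)
        else:
            # Standard format: the completion is already a string
            text = completion
        rewards.append(-abs(20 - len(text)))
    return rewards
```

For example, `reward_len(["x" * 20, [{"role": "assistant", "content": "x" * 25}]])` returns `[0, -5]`.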

Issue 1 - Cannot train GRPO with multiple images per prompt on LLaVA-1.5 (llava-hf/llava-1.5-7b-hf):

  • Output: ValueError: Image features and image tokens do not match: tokens: 3456, features 2359296
  • Reason: The batch has no image_grid_thw entry (LLaVA does not produce one), so split_pixel_values_by_grid returns the batch without splitting pixel_values, and _get_per_token_logps_and_entropies then slices pixel_values as if every sample had exactly one image.
  • My implementation:

In trl.trainer.utils:

from itertools import accumulate
from typing import Union

import torch


def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]:
    """
    Splits `batch["pixel_values"]` into a list of per-sample tensors using `batch["num_images"]`, keeping other
    entries unchanged. If `batch["image_grid_thw"]` is present (e.g. Qwen-VL), each image spans the product of its
    grid row; otherwise (e.g. LLaVA), each image is assumed to span exactly one row of `pixel_values`.
    """
    if "pixel_values" not in batch or "num_images" not in batch:
        return batch  # type: ignore

    if "image_grid_thw" not in batch:
        # No grid (e.g. LLaVA): pixel_values is [total_images, C, H, W], one row per image
        lengths = sum(batch["num_images"]) * [1]
    else:
        lengths = batch["image_grid_thw"].prod(-1).tolist()  # [num_images]
    pixel_values = batch["pixel_values"]  # [total_rows, ...]

    if sum(lengths) != pixel_values.size(0):
        raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}")

    boundaries = [0, *accumulate(batch["num_images"])]  # [3, 4, 5] -> [0, 3, 7, 12]
    sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))]
    split_values = list(torch.split(pixel_values, sections, dim=0))  # type: ignore

    if "image_grid_thw" in batch:
        image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0))  # type: ignore
        return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw}

    return {**batch, "pixel_values": split_values}
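The section arithmetic of the patched function can be checked without a model or torch. The sketch below mirrors its `lengths`/`sections` computation in plain Python; `compute_sections` is a hypothetical helper name, the extra length check is an added sanity check, and the grid values in the usage are made up for illustration.

```python
from itertools import accumulate
from math import prod
from typing import Optional, Sequence


def compute_sections(
    num_images: Sequence[int],
    image_grid_thw: Optional[Sequence[Sequence[int]]] = None,
) -> list[int]:
    """Return how many pixel_values rows belong to each sample.

    With a grid (Qwen-VL style), each image contributes t * h * w rows;
    without one (LLaVA style), each image contributes exactly one row.
    """
    if image_grid_thw is None:
        lengths = [1] * sum(num_images)
    else:
        lengths = [prod(thw) for thw in image_grid_thw]
    # Sanity check (not in the patch): one length per image
    if len(lengths) != sum(num_images):
        raise ValueError("expected one grid row per image")
    boundaries = [0, *accumulate(num_images)]  # [3, 4, 5] -> [0, 3, 7, 12]
    return [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(num_images))]
```

Without a grid, `compute_sections([3, 4, 5])` gives `[3, 4, 5]` (one row per image); with `image_grid_thw=[[1, 2, 2], [1, 4, 4]]` and `num_images=[1, 1]` it gives `[4, 16]`.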

In trl.trainer.grpo_trainer, inside GRPOTrainer._get_per_token_logps_and_entropies (the branch taken when there is no image_grid_thw):

elif pixel_values is not None:
    # num_images[i] = number of images in sample i; cum_imgs[i] = index of sample i's first image
    cum_imgs = torch.tensor([0] + num_images).cumsum(0)  # type: ignore
    img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
    model_inputs["pixel_values"] = pixel_values[img_start:img_end]
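The effect of the patched elif branch can be illustrated with a plain list standing in for pixel_values. The sketch contrasts slicing by sample index (one image per sample assumed, the behaviour described in the reason above) with slicing by cumulative image counts; `naive_slice` and `cumulative_slice` are hypothetical names for illustration only.

```python
from itertools import accumulate
from typing import Sequence, TypeVar

T = TypeVar("T")


def naive_slice(images: Sequence[T], start: int, batch_size: int) -> list[T]:
    # Assumes one image per sample: sample i owns images[i]
    return list(images[start : start + batch_size])


def cumulative_slice(images: Sequence[T], num_images: Sequence[int], start: int, batch_size: int) -> list[T]:
    # Same idea as the patched branch: cum_imgs[i] is the index of sample i's first image
    cum_imgs = [0, *accumulate(num_images)]
    img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
    return list(images[img_start:img_end])


# Flattened images for 3 samples holding 1, 2 and 3 images respectively
images = ["s0_a", "s1_a", "s1_b", "s2_a", "s2_b", "s2_c"]
num_images = [1, 2, 3]
print(naive_slice(images, start=1, batch_size=2))
print(cumulative_slice(images, num_images, start=1, batch_size=2))
```

Here naive_slice returns only the two images of sample 1, while cumulative_slice returns all five images of samples 1 and 2. The numbers in the Issue 1 error are consistent with this kind of mismatch, assuming 576 image tokens per image for LLaVA-1.5 (24 × 24 patches of a 336 × 336 input), a hidden size of 4096 for the 7B model, and that the reported feature count is an element count: 3456 = 6 × 576 image tokens, while 2359296 = 1 × 576 × 4096, i.e. features for a single image.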

Issue 2 - Cannot train with vLLM in server mode (trl vllm-serve)

  • Output:
    output = self.vllm_client.generate(
  File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 240, in generate
    images = [pil_to_base64(img) for img in images] if images else None
  File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 240, in <listcomp>
    images = [pil_to_base64(img) for img in images] if images else None
  File "/home/nelson/miniconda/envs/dlm/lib/python3.10/site-packages/trl/extras/vllm_client.py", line 235, in pil_to_base64
    image.save(buffer, format="PNG")
AttributeError: 'list' object has no attribute 'save'
  • Reason: the client converts images to base64 with:
    def pil_to_base64(image):
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        img_bytes = buffer.getvalue()
        return base64.b64encode(img_bytes).decode("utf-8")
    # Convert PIL images to base64 strings
    images = [pil_to_base64(img) for img in images] if images else None
    This expects a flat list with one image per prompt. With multiple images per prompt, images is a list of lists, so pil_to_base64 receives a list and fails on .save.
  • Also, on the server side, the vLLM inputs are built with:
    for prompt, image in zip(request.prompts, request.images):
        row = {"prompt": prompt}
        if image is not None:
            row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
        prompts.append(row)
    which also seems to expect a single image per prompt.
  • Solutions: work in progress.
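One possible direction for Issue 2, sketched under assumptions: the client encodes every image of a prompt and sends a list of base64 strings per prompt, and the server decodes that list into multi_modal_data={"image": [...]}, which vLLM accepts for multi-image prompts. encode_images and build_vllm_inputs are hypothetical helpers, not TRL API; the server's request schema would also have to accept a list per prompt, and vLLM's limit_mm_per_prompt engine argument may need to allow more than one image per prompt.

```python
import base64
from io import BytesIO
from typing import Any, Callable, Optional


def pil_to_base64(image: Any) -> str:
    # Same as the client helper quoted above: one image -> PNG bytes -> base64 string
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def encode_images(images: Optional[list]) -> Optional[list]:
    """Client side (hypothetical): accept one image, a list of images, or None per prompt."""
    if not images:
        return None
    encoded = []
    for item in images:
        if item is None:
            encoded.append(None)
        elif isinstance(item, (list, tuple)):
            encoded.append([pil_to_base64(img) for img in item])  # multi-image prompt
        else:
            encoded.append(pil_to_base64(item))  # single-image prompt
    return encoded


def decode_image(b64: str) -> Any:
    # Same decoding as the server snippet quoted above (requires Pillow)
    from PIL import Image

    return Image.open(BytesIO(base64.b64decode(b64)))


def build_vllm_inputs(
    prompts: list[str],
    images: Optional[list],
    decode: Callable[[str], Any] = decode_image,
) -> list[dict]:
    """Server side (hypothetical): build vLLM prompt dicts with one or many images per prompt."""
    images = images or [None] * len(prompts)
    rows = []
    for prompt, image in zip(prompts, images):
        row = {"prompt": prompt}
        if isinstance(image, list):
            row["multi_modal_data"] = {"image": [decode(img) for img in image]}
        elif image is not None:
            row["multi_modal_data"] = {"image": decode(image)}
        rows.append(row)
    return rows
```

Keeping single images as plain strings preserves the current single-image request format, so only multi-image prompts take the new list path.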

Issue 3 - Gradient Checkpointing

  • Cannot train with:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

With use_reentrant: false, training reports that gradients will be None; with use_reentrant: true, it raises a "has been marked as ready twice" error. For now I have disabled gradient checkpointing, but I will need it later, so I am still looking into it.

System Info

  • Platform: Linux-6.14.0-1015-x86_64-with-glibc2.39
  • Python version: 3.10.18
  • TRL version: 0.24.0
  • PyTorch version: 2.8.0
  • accelerator(s): NVIDIA L40S, NVIDIA L40S, NVIDIA L40S, NVIDIA L40S
  • Transformers version: 4.57.1
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • Datasets version: 4.3.0
  • HF Hub version: 0.36.0
  • bitsandbytes version: 0.47.0
  • DeepSpeed version: 0.17.4
  • Liger-Kernel version: 0.6.1
  • LLM-Blender version: not installed
  • OpenAI version: 1.100.2
  • PEFT version: 0.17.1
  • vLLM version: 0.11.0

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete

Metadata

  • Assignees: No one assigned
  • Labels: 🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
