
Bottleneck in GRPO training #2887

@ZYM66

Feature request

The default code deploys the generation model on a single GPU, which is slow and inefficient: all other training processes sit idle while they wait for vLLM to finish generating.

[Screenshot of GPU utilization: GPU 7 hosts the vLLM model while the training GPUs wait.]

Relevant source code:

if self.accelerator.is_main_process:
    vllm_device = self.args.vllm_device
    if vllm_device == "auto":
        if torch.cuda.device_count() == 1:
            vllm_device = "cuda:0"  # particular case when training with onyl 1 GPU: share it
        else:
            vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
    # Check that the requested device is available
    if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
        raise ValueError(
            f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
            "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
            "value lower than the number of GPUs available on your machine—typically, reducing it by one "
            f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
        )
    # Check that the requested device is not also used for training
    if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
        warnings.warn(
            f"The requested device {vllm_device} is also being used for training. For higher throughput "
            "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
            "If this is intentional, you may ignore this warning but should adjust "
            "`vllm_gpu_memory_utilization` accordingly."
        )
    # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
    # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
    # setting (profiling_patch).
    world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
    profiling_patch = patch(
        "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
    )
    with world_size_patch, profiling_patch:
        self.llm = LLM(
            model=model.name_or_path,
            device=vllm_device,
            gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
            dtype=self.args.vllm_dtype,
            # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
            # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
            # This is particularly useful here because we generate completions from the same prompts.
            enable_prefix_caching=True,
            max_model_len=self.args.vllm_max_model_len,
        )
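
One possible direction, sketched below as a standalone example rather than a patch to the trainer: vLLM can shard a model across several GPUs through the `tensor_parallel_size` argument of `LLM`, so generation could run on a group of dedicated GPUs instead of being pinned to a single one. The GPU indices, model name, and sampling settings below are placeholders, not values taken from the trainer.

# Standalone sketch: give vLLM two spare GPUs and shard the generation model
# across them with tensor parallelism. Indices and model name are placeholders.
import os

# Restrict this process (and the workers vLLM spawns) to the GPUs that are not
# used for training; must be set before CUDA is initialized.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6,7")

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder; use the policy model path
    tensor_parallel_size=2,              # shard across the two visible GPUs
    gpu_memory_utilization=0.9,
    enable_prefix_caching=True,          # reuse KV cache across shared prompt prefixes
)

outputs = llm.generate(
    ["The quick brown fox"],
    SamplingParams(max_tokens=16),
)
print(outputs[0].outputs[0].text)

With generation sharded across more GPUs, the time the training processes spend waiting on the main process should drop, at the cost of reserving additional GPUs for inference.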

Motivation

Running generation on a single GPU is too slow.

Your contribution
