
Commit

Support multiple vllms (modelscope#3202)
tastelikefeet authored Feb 21, 2025
1 parent 16ae13c commit 8921d9b
Showing 7 changed files with 106 additions and 41 deletions.
3 changes: 2 additions & 1 deletion docs/source/Instruction/命令行参数.md
@@ -373,7 +373,8 @@ reward模型参数将在PPO、GRPO中使用。
- log_completions: 是否记录训练中的模型生成内容,搭配 `--report_to wandb` 使用。默认为False
- 提示:若没有设置`--report_to wandb`,则会在checkpoint中创建`completions.jsonl`来存储生成内容
- use_vllm: 是否使用vLLM作为GRPO生成的infer_backend,默认为False
- vllm_device: 设置vLLM部署的设备,比如部署在卡0上,则`cuda:1`, 默认为`auto`, 即使用最后一张卡
- num_infer_workers: 每个node上推理worker数量,仅对vllm或者lmdeploy时有效
- vllm_device: 设置vLLM部署的设备,可以设置为`auto`,代表按照num_infer_workers数量使用最后的几张卡,否则请传入和num_infer_workers相等数量的设备,例如`--vllm_device cuda:1 cuda:2`
- vllm_gpu_memory_utilization: vllm透传参数,默认为0.9
- vllm_max_model_len: vllm透传参数,默认为None
- vllm_max_num_seqs: vllm透传参数,默认为256
3 changes: 2 additions & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -384,7 +384,8 @@ The meanings of the following parameters can be referenced [here](https://huggin
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb`, default is False.
- Note: If `--report_to wandb` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
- use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
- vllm_device: Set the device for vLLM deployment. For example, if deployed on card 0, use `cuda:0`; default is `auto`, which means using the last available GPU.
- num_infer_workers: The number of inference workers per node. This setting is only effective when using vLLM or lmdeploy.
- vllm_device: Configures the devices for deploying vLLM. You can set it to `auto`, which allocates the last few GPUs according to num_infer_workers; otherwise, specify exactly num_infer_workers devices, for example `--vllm_device cuda:1 cuda:2` (see the sketch after this list).
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
- vllm_max_model_len: vLLM passthrough parameter, default is None.
- vllm_max_num_seqs: vLLM passthrough parameter, default is 256.
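
For illustration, a minimal sketch of how the `auto` setting is expected to resolve, mirroring the allocation logic this commit adds to `grpo_trainer.py`; the function name and the device counts below are made up for the sketch. The last num_infer_workers cards go to inference, the rest stay with training.

```python
# Sketch only: mirrors the `auto` resolution added in grpo_trainer.py in this commit.
# `device_count` is a plain argument here, not a real API call.
from typing import List


def resolve_vllm_devices(vllm_device: List[str], num_infer_workers: int, device_count: int) -> List[str]:
    """Return the device list that the trainer is expected to hand to vLLM/LMDeploy."""
    if vllm_device[0] == 'auto':
        if device_count == 1:
            return ['cuda:0']  # single-GPU case: training and inference share the card
        # take the last `num_infer_workers` cards, e.g. cuda:6 and cuda:7 on an 8-GPU node
        return [f'cuda:{idx}' for idx in range(device_count - num_infer_workers, device_count)]
    assert len(vllm_device) == num_infer_workers, 'pass exactly num_infer_workers devices'
    return vllm_device


print(resolve_vllm_devices(['auto'], 2, 8))              # ['cuda:6', 'cuda:7']
print(resolve_vllm_devices(['cuda:1', 'cuda:2'], 2, 8))  # ['cuda:1', 'cuda:2']
```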
2 changes: 1 addition & 1 deletion swift/llm/argument/rlhf_args.py
@@ -49,7 +49,7 @@ class GRPOArguments(GRPOArgumentsMixin):

# vLLM in GRPO
use_vllm: bool = False
vllm_device: Optional[str] = 'auto' # 'cuda:0'
vllm_device: List[str] = field(default_factory=lambda: ['auto'])
vllm_gpu_memory_utilization: float = 0.9
vllm_max_model_len: Optional[int] = None

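
The `Optional[str]` to `List[str]` change above is what allows several devices on the command line. A minimal sketch of the parsing behaviour with `HfArgumentParser` on a stand-in dataclass, assuming swift's CLI follows the usual dataclass-field rules; `FakeGRPOArguments` is not a real class.

```python
# Stand-in dataclass for illustration; the real GRPOArguments lives in swift/llm/argument/rlhf_args.py.
from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class FakeGRPOArguments:
    vllm_device: List[str] = field(default_factory=lambda: ['auto'])
    num_infer_workers: int = 1


parser = HfArgumentParser(FakeGRPOArguments)
(args,) = parser.parse_args_into_dataclasses(['--vllm_device', 'cuda:1', 'cuda:2', '--num_infer_workers', '2'])
print(args.vllm_device)  # ['cuda:1', 'cuda:2']
```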
1 change: 1 addition & 0 deletions swift/trainers/arguments.py
@@ -72,6 +72,7 @@ def place_model_on_device(self):
class GRPOArgumentsMixin:

# vllm_device, vllm_gpu_memory_utilization, and vllm_max_model_len are defined in HfGRPOConfig.
num_infer_workers: int = 1
vllm_max_num_seqs: int = 256
vllm_enforce_eager: bool = False
vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
127 changes: 91 additions & 36 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -5,20 +5,20 @@
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from accelerate.utils import gather, gather_object, is_peft_model, set_seed
from transformers import PreTrainedModel
from transformers.utils.versions import require_version
from trl import GRPOTrainer as HFGRPOTrainer
from trl.models import unwrap_model_for_generation

from swift.llm import InferRequest, RequestConfig, RowPreprocessor, to_device
from swift.plugin import orms
from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, is_lmdeploy_available,
is_vllm_available, is_wandb_available)
from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, get_node_setting,
is_lmdeploy_available, is_vllm_available, is_wandb_available)
from ..mixin import SwiftMixin
from .rlhf_mixin import RLHFTrainerMixin

@@ -109,44 +109,45 @@ def __init__(self,
set_seed(args.seed, device_specific=True)

if use_vllm or use_lmdeploy:
if self.accelerator.is_main_process:
if self.infer_rank >= 0:
fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device
if fast_infer_device == 'auto':
if fast_infer_device[0] == 'auto':
if get_device_count() == 1:
fast_infer_device = get_device() # particular case when training with only 1 GPU: share it
fast_infer_device = [get_device()] # particular case when training with only 1 GPU: share it
else:
local_world_size = get_dist_setting()[3]
fast_infer_device = get_device(local_world_size) # take the next GPU idx
# Check that the requested device is available
if fast_infer_device.split(':')[0] in {'cuda', 'npu'
} and int(fast_infer_device.split(':')[1]) >= get_device_count():
raise ValueError(
f'The requested device for vllm ({fast_infer_device}) is not available. '
f'You are likely using vLLM '
'without restricting the number of GPUs for training. Set the `--num_processes` argument to a '
'value lower than the number of GPUs available on your machine—typically, reducing it by one '
f'is sufficient. In your case: `--num_processes {get_device_count() - 1}`.')
# Check that the requested device is not also used for training
if fast_infer_device in {get_device(idx) for idx in range(self.accelerator.num_processes)}:
logger.warning(
f'The requested device {fast_infer_device} is also used for training. '
f'This may lead to unexpected behavior. It is recommended to use a dedicated device for vLLM.')
fast_infer_device = []
for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()):
fast_infer_device.append(get_device(idx))

for _device in fast_infer_device:
# Check that the requested device is available
if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count():
raise ValueError(f'The requested device for vllm ({_device}) is not available. '
f'You are likely using vLLM '
'without restricting the number of GPUs for training. '
'Set the `--num_processes` argument to a '
'value lower than the number of GPUs available on your machine—typically, '
'reducing it by one is sufficient. '
f'In your case: `--num_processes {get_device_count() - 1}`.')
# Check that the requested device is not also used for training
if _device in {get_device(idx) for idx in range(self.accelerator.num_processes)}:
logger.warning(f'The requested device {_device} is also used for training. '
f'This may lead to unexpected behavior. '
f'It is recommended to use a dedicated device for vLLM.')

if use_vllm:
if not is_vllm_available():
raise ImportError('vLLM is not available and `use_vllm` is set to True. '
'Please install vLLM with `pip install vllm` to use it.')
from swift.llm import VllmEngine
world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
profiling_patch = patch(
'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling',
return_value=None)
from swift.tuners import Swift
with world_size_patch, profiling_patch, Swift.grpo_context(model, self.template.processor):
from swift.llm.utils import patch_vllm
with patch_vllm(), Swift.grpo_context(model, self.template.processor):
self.engine = VllmEngine(
model.model_dir,
model.model_info.torch_dtype,
model_type=model.model_meta.model_type,
device=fast_infer_device,
device=fast_infer_device[self.local_infer_rank],
gpu_memory_utilization=args.vllm_gpu_memory_utilization,
enable_prefix_caching=args.vllm_enable_prefix_caching,
max_num_seqs=args.vllm_max_num_seqs,
@@ -169,7 +170,7 @@ def __init__(self,
from swift.llm import LmdeployEngine
from swift.tuners import Swift
with Swift.grpo_context(model, self.template.processor):
fast_infer_device = int(fast_infer_device.split(':')[1])
fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1])
self.engine = LmdeployEngine(
model.model_dir,
model.model_info.torch_dtype,
@@ -203,6 +204,40 @@ def __init__(self,
self.log_completions = args.log_completions
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))

@property
def infer_rank(self):
rank, local_rank, world_size, local_world_size = get_dist_setting()
assert local_world_size % self.args.num_infer_workers == 0
assert local_world_size + self.args.num_infer_workers == get_device_count()
step = local_world_size // self.args.num_infer_workers
for _vllm_rank in range(self.args.num_infer_workers):
_assigned = _vllm_rank * step
if local_rank == _assigned:
return get_node_setting()[0] * self.args.num_infer_workers + _vllm_rank

return -1

@property
def local_infer_rank(self):
rank, local_rank, world_size, local_world_size = get_dist_setting()
assert local_world_size % self.args.num_infer_workers == 0
assert local_world_size + self.args.num_infer_workers == get_device_count()
step = local_world_size // self.args.num_infer_workers
for _vllm_rank in range(self.args.num_infer_workers):
_assigned = _vllm_rank * step
if local_rank == _assigned:
return _vllm_rank

return -1

@staticmethod
def round_robin(num_reqs, nodes):
distribution = [[] for _ in range(nodes)]
for idx in range(num_reqs):
node_id = idx % nodes
distribution[node_id].append(idx)
return distribution

@staticmethod
@contextmanager
def _template_context(template):
Expand Down Expand Up @@ -242,7 +277,7 @@ def _move_model_to_vllm_lmdeploy(self):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
if self.infer_rank >= 0:
if self.args.use_vllm:
llm_model = self.engine.engine.engine.model_executor.driver_worker.model_runner.model
else:
@@ -253,9 +288,20 @@ def _move_model_to_vllm_lmdeploy(self):
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

@staticmethod
def reorder_outputs(outputs, distributed_idx):
index_to_output = {}
current_position = 0
for output_idx in distributed_idx:
for idx in output_idx:
index_to_output[idx] = outputs[current_position]
current_position += 1

return [index_to_output[idx] for idx in sorted(index_to_output.keys())]

def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device

rank, local_rank, world_size, local_world_size = get_dist_setting()
# Generate completions using either vLLM or regular generation
if self.args.use_vllm or self.args.use_lmdeploy:
# First, have main process load weights if needed
@@ -264,14 +310,23 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_inputs = gather_object(inputs)
if self.accelerator.is_main_process:
outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False)
# Distribute inputs to different workers
# for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker
# 1/3/5 dispatch to the second worker
# trying to shuffle and average the length
distributed_idx = self.round_robin(len(all_inputs), get_node_setting()[1] * self.args.num_infer_workers)
if self.infer_rank >= 0:
outputs = self.engine.infer(
np.array(all_inputs)[distributed_idx[self.infer_rank]], self.request_config, use_tqdm=False)
else:
outputs = [None] * len(all_inputs)
outputs = []

outputs = gather_object(outputs)
outputs = self.reorder_outputs(outputs, distributed_idx)

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
outputs = broadcast_object_list(outputs, from_process=0)
# outputs = broadcast_object_list(outputs, from_process=0)
else:
# Regular generation path
is_multimodal = self.model.model_meta.is_multimodal
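
To follow the new rank bookkeeping in `infer_rank` / `local_infer_rank`, here is a standalone sketch of the same arithmetic with concrete numbers; it lifts the property bodies out of the trainer, so `local_rank`, `local_world_size`, and `num_infer_workers` become plain arguments (an assumption of the sketch, not the trainer's real signature).

```python
# Standalone rework of GRPOTrainer.infer_rank / local_infer_rank, for illustration only.
# Assumes the commit's invariant: local_world_size + num_infer_workers == device_count.

def local_infer_rank(local_rank: int, local_world_size: int, num_infer_workers: int) -> int:
    assert local_world_size % num_infer_workers == 0
    step = local_world_size // num_infer_workers
    for vllm_rank in range(num_infer_workers):
        if local_rank == vllm_rank * step:
            return vllm_rank
    return -1  # this training process does not own an inference engine


def infer_rank(node_rank: int, local_rank: int, local_world_size: int, num_infer_workers: int) -> int:
    local = local_infer_rank(local_rank, local_world_size, num_infer_workers)
    return -1 if local < 0 else node_rank * num_infer_workers + local


# One node with 8 GPUs: 6 training processes (local_rank 0..5) plus 2 vLLM workers on cuda:6/cuda:7.
for local_rank in range(6):
    print(local_rank, infer_rank(node_rank=0, local_rank=local_rank, local_world_size=6, num_infer_workers=2))
# -> local_rank 0 drives global infer worker 0, local_rank 3 drives worker 1, all others print -1.
```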
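Likewise, the dispatch comment in `_prepare_inputs` ("2 workers, 6 inputs, 0/2/4 to the first worker, 1/3/5 to the second") can be traced end to end with the two new static methods, `round_robin` and `reorder_outputs`, copied here outside the class and lightly condensed; strings stand in for real completions.

```python
# Round trip through the two static methods added in this commit.

def round_robin(num_reqs, nodes):
    distribution = [[] for _ in range(nodes)]
    for idx in range(num_reqs):
        distribution[idx % nodes].append(idx)
    return distribution


def reorder_outputs(outputs, distributed_idx):
    index_to_output = {}
    current_position = 0
    for output_idx in distributed_idx:
        for idx in output_idx:
            index_to_output[idx] = outputs[current_position]
            current_position += 1
    return [index_to_output[idx] for idx in sorted(index_to_output)]


all_inputs = [f'prompt-{i}' for i in range(6)]
distributed_idx = round_robin(len(all_inputs), nodes=2)  # [[0, 2, 4], [1, 3, 5]]
# Each infer worker handles its slice, then results are gathered worker by worker...
gathered = [all_inputs[i] for i in distributed_idx[0]] + [all_inputs[i] for i in distributed_idx[1]]
# ...and reorder_outputs restores the original prompt order for every training process.
print(reorder_outputs(gathered, distributed_idx) == all_inputs)  # True
```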
5 changes: 3 additions & 2 deletions swift/utils/__init__.py
@@ -1,7 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .env import (get_dist_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_dist_ta, is_local_master,
is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph, use_hf_hub, use_torchacc)
from .env import (get_dist_setting, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist,
is_dist_ta, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph,
use_hf_hub, use_torchacc)
from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_swanlab_available,
is_unsloth_available, is_vllm_available, is_wandb_available, is_xtuner_available)
from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl
6 changes: 6 additions & 0 deletions swift/utils/env.py
@@ -33,6 +33,12 @@ def get_dist_setting() -> Tuple[int, int, int, int]:
return rank, local_rank, world_size, local_world_size


def get_node_setting():
node_rank = int(os.getenv('NODE_RANK', 0))
nnodes = int(os.getenv('NNODES', 1))
return node_rank, nnodes


def is_local_master():
local_rank = get_dist_setting()[1]
return local_rank in {-1, 0}
