From 8921d9b98310d93f9f111af8859358ee32dce687 Mon Sep 17 00:00:00 2001
From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
Date: Fri, 21 Feb 2025 15:25:06 +0800
Subject: [PATCH] Support multiple vllms (#3202)

---
 ...44\350\241\214\345\217\202\346\225\260.md" | 3 +-
 .../Instruction/Command-line-parameters.md | 3 +-
 swift/llm/argument/rlhf_args.py | 2 +-
 swift/trainers/arguments.py | 1 +
 swift/trainers/rlhf_trainer/grpo_trainer.py | 127 +++++++++++++-----
 swift/utils/__init__.py | 5 +-
 swift/utils/env.py | 6 +
 7 files changed, 106 insertions(+), 41 deletions(-)

diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index 5e2abd7ad..23bbb2ad4 100644
--- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -373,7 +373,8 @@ The reward model parameters are used in PPO and GRPO.
 - log_completions: Whether to log the model's generated content during training; use together with `--report_to wandb`. Defaults to False
   - Note: if `--report_to wandb` is not set, a `completions.jsonl` is created in the checkpoint directory to store the generated content
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation. Defaults to False
-- vllm_device: Sets the device on which vLLM is deployed, e.g. `cuda:0` to deploy on card 0. Defaults to `auto`, i.e. the last GPU is used
+- num_infer_workers: The number of inference workers per node. Only effective with vllm or lmdeploy
+- vllm_device: Sets the devices on which vLLM is deployed. `auto` uses the last few GPUs according to num_infer_workers; otherwise pass exactly num_infer_workers devices, e.g. `--vllm_device cuda:1 cuda:2`
 - vllm_gpu_memory_utilization: vLLM passthrough parameter. Defaults to 0.9
 - vllm_max_model_len: vLLM passthrough parameter. Defaults to None
 - vllm_max_num_seqs: vLLM passthrough parameter. Defaults to 256
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index a17d9931e..0eaa30841 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -384,7 +384,8 @@ The meanings of the following parameters can be referenced [here](https://huggin
 - log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb`, default is False.
   - Note: If `--report_to wandb` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
 - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False.
-- vllm_device: Set the device for vLLM deployment. For example, if deployed on card 0, use `cuda:0`; default is `auto`, which means using the last available GPU.
+- num_infer_workers: The number of inference workers per node. This setting is only effective when using vLLM or lmdeploy.
+- vllm_device: Configures the devices for deploying vLLM. Set it to `auto` to allocate the last few GPUs based on the value of `num_infer_workers`, or specify a number of devices equal to `num_infer_workers`, for example `--vllm_device cuda:1 cuda:2`.
 - vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
 - vllm_max_model_len: vLLM passthrough parameter, default is None.
 - vllm_max_num_seqs: vLLM passthrough parameter, default is 256.
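
The `auto` rule described above is easiest to see as code. The following is a minimal, standalone sketch of the allocation the docs describe: with `--vllm_device auto`, the last `num_infer_workers` GPUs of each node are reserved for inference, and a single-GPU machine shares its only card between training and inference. The helper name `pick_infer_devices` and the concrete GPU counts are illustrative assumptions, not part of the patch.

def pick_infer_devices(device_count: int, num_infer_workers: int) -> list:
    # 'auto' reserves the last num_infer_workers GPUs of the node for inference;
    # the remaining GPUs stay available for the training processes.
    if device_count == 1:
        # Single-GPU machine: training and inference share the only card.
        return ['cuda:0']
    start = device_count - num_infer_workers
    return [f'cuda:{idx}' for idx in range(start, device_count)]

# An 8-GPU node trained with --num_infer_workers 2 keeps cuda:0..cuda:5 for training
# and hands cuda:6 and cuda:7 to the inference workers.
assert pick_infer_devices(8, 2) == ['cuda:6', 'cuda:7']
assert pick_infer_devices(1, 1) == ['cuda:0']
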
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index a2f73db03..2d3907def 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -49,7 +49,7 @@ class GRPOArguments(GRPOArgumentsMixin): # vLLM in GRPO use_vllm: bool = False - vllm_device: Optional[str] = 'auto' # 'cuda:0' + vllm_device: List[str] = field(default_factory=lambda: ['auto']) vllm_gpu_memory_utilization: float = 0.9 vllm_max_model_len: Optional[int] = None diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index c916ecc10..b5c4e7e09 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -72,6 +72,7 @@ def place_model_on_device(self): class GRPOArgumentsMixin: # vllm_device, vllm_gpu_memory_utilization, and vllm_max_model_len are defined in HfGRPOConfig. + num_infer_workers: int = 1 vllm_max_num_seqs: int = 256 vllm_enforce_eager: bool = False vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 5c66ac093..c86b9d304 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -5,11 +5,11 @@ from collections import defaultdict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Union -from unittest.mock import patch +import numpy as np import torch import torch.nn as nn -from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from accelerate.utils import gather, gather_object, is_peft_model, set_seed from transformers import PreTrainedModel from transformers.utils.versions import require_version from trl import GRPOTrainer as HFGRPOTrainer @@ -17,8 +17,8 @@ from swift.llm import InferRequest, RequestConfig, RowPreprocessor, to_device from swift.plugin import orms -from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, is_lmdeploy_available, - is_vllm_available, is_wandb_available) +from swift.utils import (JsonlWriter, get_device, get_device_count, get_dist_setting, get_logger, get_node_setting, + is_lmdeploy_available, is_vllm_available, is_wandb_available) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin @@ -109,44 +109,45 @@ def __init__(self, set_seed(args.seed, device_specific=True) if use_vllm or use_lmdeploy: - if self.accelerator.is_main_process: + if self.infer_rank >= 0: fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device - if fast_infer_device == 'auto': + if fast_infer_device[0] == 'auto': if get_device_count() == 1: - fast_infer_device = get_device() # particular case when training with only 1 GPU: share it + fast_infer_device = [get_device()] # particular case when training with only 1 GPU: share it else: - local_world_size = get_dist_setting()[3] - fast_infer_device = get_device(local_world_size) # take the next GPU idx - # Check that the requested device is available - if fast_infer_device.split(':')[0] in {'cuda', 'npu' - } and int(fast_infer_device.split(':')[1]) >= get_device_count(): - raise ValueError( - f'The requested device for vllm ({fast_infer_device}) is not available. ' - f'You are likely using vLLM ' - 'without restricting the number of GPUs for training. Set the `--num_processes` argument to a ' - 'value lower than the number of GPUs available on your machine—typically, reducing it by one ' - f'is sufficient. 
In your case: `--num_processes {get_device_count() - 1}`.') - # Check that the requested device is not also used for training - if fast_infer_device in {get_device(idx) for idx in range(self.accelerator.num_processes)}: - logger.warning( - f'The requested device {fast_infer_device} is also used for training. ' - f'This may lead to unexpected behavior. It is recommended to use a dedicated device for vLLM.') + fast_infer_device = [] + for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()): + fast_infer_device.append(get_device(idx)) + + for _device in fast_infer_device: + # Check that the requested device is available + if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count(): + raise ValueError(f'The requested device for vllm ({_device}) is not available. ' + f'You are likely using vLLM ' + 'without restricting the number of GPUs for training. ' + 'Set the `--num_processes` argument to a ' + 'value lower than the number of GPUs available on your machine—typically, ' + 'reducing it by one is sufficient. ' + f'In your case: `--num_processes {get_device_count() - 1}`.') + # Check that the requested device is not also used for training + if _device in {get_device(idx) for idx in range(self.accelerator.num_processes)}: + logger.warning(f'The requested device {_device} is also used for training. ' + f'This may lead to unexpected behavior. ' + f'It is recommended to use a dedicated device for vLLM.') + if use_vllm: if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' 'Please install vLLM with `pip install vllm` to use it.') from swift.llm import VllmEngine - world_size_patch = patch('torch.distributed.get_world_size', return_value=1) - profiling_patch = patch( - 'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', - return_value=None) from swift.tuners import Swift - with world_size_patch, profiling_patch, Swift.grpo_context(model, self.template.processor): + from swift.llm.utils import patch_vllm + with patch_vllm(), Swift.grpo_context(model, self.template.processor): self.engine = VllmEngine( model.model_dir, model.model_info.torch_dtype, model_type=model.model_meta.model_type, - device=fast_infer_device, + device=fast_infer_device[self.local_infer_rank], gpu_memory_utilization=args.vllm_gpu_memory_utilization, enable_prefix_caching=args.vllm_enable_prefix_caching, max_num_seqs=args.vllm_max_num_seqs, @@ -169,7 +170,7 @@ def __init__(self, from swift.llm import LmdeployEngine from swift.tuners import Swift with Swift.grpo_context(model, self.template.processor): - fast_infer_device = int(fast_infer_device.split(':')[1]) + fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1]) self.engine = LmdeployEngine( model.model_dir, model.model_info.torch_dtype, @@ -203,6 +204,40 @@ def __init__(self, self.log_completions = args.log_completions self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl')) + @property + def infer_rank(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + assert local_world_size % self.args.num_infer_workers == 0 + assert local_world_size + self.args.num_infer_workers == get_device_count() + step = local_world_size // self.args.num_infer_workers + for _vllm_rank in range(self.args.num_infer_workers): + _assigned = _vllm_rank * step + if local_rank == _assigned: + return get_node_setting()[0] * self.args.num_infer_workers + _vllm_rank + + return -1 + + 
@property + def local_infer_rank(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + assert local_world_size % self.args.num_infer_workers == 0 + assert local_world_size + self.args.num_infer_workers == get_device_count() + step = local_world_size // self.args.num_infer_workers + for _vllm_rank in range(self.args.num_infer_workers): + _assigned = _vllm_rank * step + if local_rank == _assigned: + return _vllm_rank + + return -1 + + @staticmethod + def round_robin(num_reqs, nodes): + distribution = [[] for _ in range(nodes)] + for idx in range(num_reqs): + node_id = idx % nodes + distribution[node_id].append(idx) + return distribution + @staticmethod @contextmanager def _template_context(template): @@ -242,7 +277,7 @@ def _move_model_to_vllm_lmdeploy(self): } else: state_dict = unwrapped_model.state_dict() - if self.accelerator.is_main_process: + if self.infer_rank >= 0: if self.args.use_vllm: llm_model = self.engine.engine.engine.model_executor.driver_worker.model_runner.model else: @@ -253,9 +288,20 @@ def _move_model_to_vllm_lmdeploy(self): if is_peft_model(unwrapped_model): unwrapped_model.unmerge_adapter() + @staticmethod + def reorder_outputs(outputs, distributed_idx): + index_to_output = {} + current_position = 0 + for output_idx in distributed_idx: + for idx in output_idx: + index_to_output[idx] = outputs[current_position] + current_position += 1 + + return [index_to_output[idx] for idx in sorted(index_to_output.keys())] + def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device - + rank, local_rank, world_size, local_world_size = get_dist_setting() # Generate completions using either vLLM or regular generation if self.args.use_vllm or self.args.use_lmdeploy: # First, have main process load weights if needed @@ -264,14 +310,23 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]: self._last_loaded_step = self.state.global_step # Generate completions using vLLM: gather all prompts and use them in a single call in the main process all_inputs = gather_object(inputs) - if self.accelerator.is_main_process: - outputs = self.engine.infer(all_inputs, self.request_config, use_tqdm=False) + # Distribute inputs to different workers + # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker + # 1/3/5 dispatch to the second worker + # trying to shuffle and average the length + distributed_idx = self.round_robin(len(all_inputs), get_node_setting()[1] * self.args.num_infer_workers) + if self.infer_rank >= 0: + outputs = self.engine.infer( + np.array(all_inputs)[distributed_idx[self.infer_rank]], self.request_config, use_tqdm=False) else: - outputs = [None] * len(all_inputs) + outputs = [] + + outputs = gather_object(outputs) + outputs = self.reorder_outputs(outputs, distributed_idx) # Broadcast the completions from the main process to all processes, ensuring each process receives its # corresponding slice. - outputs = broadcast_object_list(outputs, from_process=0) + # outputs = broadcast_object_list(outputs, from_process=0) else: # Regular generation path is_multimodal = self.model.model_meta.is_multimodal diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index 1d6d7ffb6..3c254fd79 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from .env import (get_dist_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_dist_ta, is_local_master, - is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph, use_hf_hub, use_torchacc) +from .env import (get_dist_setting, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, + is_dist_ta, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph, + use_hf_hub, use_torchacc) from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_swanlab_available, is_unsloth_available, is_vllm_available, is_wandb_available, is_xtuner_available) from .io_utils import JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, read_from_jsonl, write_to_jsonl diff --git a/swift/utils/env.py b/swift/utils/env.py index 513114988..29a6cbe79 100644 --- a/swift/utils/env.py +++ b/swift/utils/env.py @@ -33,6 +33,12 @@ def get_dist_setting() -> Tuple[int, int, int, int]: return rank, local_rank, world_size, local_world_size +def get_node_setting(): + node_rank = int(os.getenv('NODE_RANK', 0)) + nnodes = int(os.getenv('NNODES', 1)) + return node_rank, nnodes + + def is_local_master(): local_rank = get_dist_setting()[1] return local_rank in {-1, 0}
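
The pieces above fit together as follows: each node launches `num_infer_workers` inference engines on its last GPUs, prompts gathered from every rank are split round-robin across `nnodes * num_infer_workers` workers (the node count comes from `get_node_setting`), and the gathered completions are mapped back to the original request order. The standalone sketch below mirrors the `round_robin` and `reorder_outputs` helpers added in grpo_trainer.py (with the worker-count parameter renamed for clarity); the request and output strings are placeholder data rather than real InferRequest objects.

def round_robin(num_reqs: int, num_workers: int) -> list:
    # Request i is assigned to worker i % num_workers, which balances the number
    # (and roughly the total length) of prompts each inference worker receives.
    distribution = [[] for _ in range(num_workers)]
    for idx in range(num_reqs):
        distribution[idx % num_workers].append(idx)
    return distribution

def reorder_outputs(outputs: list, distributed_idx: list) -> list:
    # Gathered outputs arrive concatenated per worker (worker 0 first, then worker 1, ...);
    # map each output back to its original request index and restore the input order.
    index_to_output = {}
    position = 0
    for worker_indices in distributed_idx:
        for idx in worker_indices:
            index_to_output[idx] = outputs[position]
            position += 1
    return [index_to_output[idx] for idx in sorted(index_to_output)]

# 6 requests over 2 workers: worker 0 handles 0/2/4 and worker 1 handles 1/3/5.
dist = round_robin(6, 2)                                   # [[0, 2, 4], [1, 3, 5]]
gathered = ['out0', 'out2', 'out4', 'out1', 'out3', 'out5']
assert reorder_outputs(gathered, dist) == ['out0', 'out1', 'out2', 'out3', 'out4', 'out5']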