Commit 6dc42ab

[grpo] Refactor GRPOVllmEngine (#4375)
* update
* memory leak and env setting
* fix v1 tp=1
1 parent: 3b668d7 · commit: 6dc42ab

File tree: 4 files changed (+17, −100 lines)

swift/llm/infer/infer_engine/grpo_vllm_engine.py
Lines changed: 13 additions & 92 deletions

@@ -1,23 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from copy import copy, deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Dict, Optional
 
 import torch
-from packaging import version
 
-from swift.llm import InferRequest, Template, VllmEngine, get_model_tokenizer
-from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
-from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest, patch_vllm_memory_leak
+from swift.llm import Template, VllmEngine
 
 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
     os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
     os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600'
-    import vllm
-    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLM
 except Exception:
     raise
 
@@ -56,23 +48,15 @@ def __init__(
         engine_kwargs: Optional[Dict[str, Any]] = None,
         template: Optional[Template] = None,
     ) -> None:
-        os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '0')
-        if engine_kwargs is None:
-            engine_kwargs = {}
-        patch_vllm_memory_leak()
-        self.use_async_engine = use_async_engine
-        self.processor = get_model_tokenizer(
-            model_id_or_path,
-            torch_dtype,
-            load_model=False,
-            download_model=True,
+        assert not use_async_engine  # TODO
+        super().__init__(
+            model_id_or_path=model_id_or_path,
+            torch_dtype=torch_dtype,
+            use_async_engine=use_async_engine,
             model_type=model_type,
             use_hf=use_hf,
             hub_token=hub_token,
-            revision=revision)[1]
-        self._post_init(template)
-
-        self._prepare_engine_kwargs(
+            revision=revision,
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
             pipeline_parallel_size=pipeline_parallel_size,
@@ -81,78 +65,15 @@ def __init__(
             disable_custom_all_reduce=disable_custom_all_reduce,
             enforce_eager=enforce_eager,
             limit_mm_per_prompt=limit_mm_per_prompt,
+            device=device,
+            seed=seed,
             enable_lora=enable_lora,
             max_loras=max_loras,
             max_lora_rank=max_lora_rank,
             enable_prefix_caching=enable_prefix_caching,
-            device=device,
-            seed=seed,
-            distributed_executor_backend=distributed_executor_backend,
             enable_sleep_mode=enable_sleep_mode,
+            distributed_executor_backend=distributed_executor_backend,
             quantization=quantization,
-            **engine_kwargs,
+            engine_kwargs=engine_kwargs,
+            template=template,
         )
-        self._prepare_engine()
-        self._load_generation_config()
-
-    def _prepare_engine(self) -> None:
-        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config):
-            engine = LLM(**self.engine_args.__dict__)
-        self.engine = engine
-
-    @property
-    def inner_model(self):
-        return self.engine.llm_engine.model_executor.driver_worker.model_runner.model
-
-    @property
-    def inner_model_executor(self):
-        return self.engine.llm_engine.model_executor
-
-    def infer(
-        self,
-        infer_requests: List[InferRequest],
-        request_config: Optional[RequestConfig] = None,
-        metrics: Optional[List[Metric]] = None,
-        *,
-        template: Optional[Template] = None,
-        use_tqdm: Optional[bool] = None,
-        adapter_request: Optional[AdapterRequest] = None,
-    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
-        request_config = deepcopy(request_config or RequestConfig())
-        if template is None:
-            template = self.default_template
-        template.set_mode('vllm')
-        batched_inputs, error_list = self._batch_encode(
-            infer_requests, template=template, strict=getattr(self, 'strict', True))
-        self.set_default_max_tokens(request_config, batched_inputs)
-
-        prompts = []
-        for inputs in batched_inputs:
-            llm_inputs = {'prompt_token_ids': inputs['input_ids']}
-            mm_data = {}
-            for key in ['images', 'audios', 'videos']:
-                media_data = inputs.get(key) or []
-                if media_data:
-                    if version.parse(vllm.__version__) < version.parse('0.6'):
-                        assert len(media_data) == 1, (
-                            f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0')
-                        mm_data = {key.rstrip('s'): media_data[0]}
-                    else:
-                        mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data}
-            if mm_data:
-                llm_inputs['multi_modal_data'] = mm_data
-            prompts.append(llm_inputs)
-
-        generation_configs = []
-        seed = request_config.seed
-        assert seed >= 0, 'Seed is needed for GRPOVllmEngine.'
-        for i, _ in enumerate(prompts):
-            request_config = copy(request_config)
-            request_config.seed = seed + i
-            generation_config = self._prepare_generation_config(request_config)
-            self._add_stop_words(generation_config, request_config, template.template_meta)
-            generation_configs.append(generation_config)
-        outputs = self.engine.generate(prompts, generation_configs, use_tqdm=False)
-        return [
-            self._create_chat_completion_response(result, template, generation_configs[0], '') for result in outputs
-        ]

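Note: the net effect of this file's diff is that GRPOVllmEngine stops building its own LLM instance and reimplementing infer(); it now just forwards its constructor arguments to VllmEngine. Below is a minimal sketch of the resulting shape, assuming ms-swift is installed; the full parameter list from the diff is abbreviated with **kwargs and the class name is illustrative, not the real one.

# Sketch only: the real __init__ forwards the complete parameter list shown above.
from typing import Any, Dict, Optional

import torch

from swift.llm import Template, VllmEngine


class GRPOVllmEngineSketch(VllmEngine):

    def __init__(self,
                 model_id_or_path: str,
                 torch_dtype: Optional[torch.dtype] = None,
                 *,
                 use_async_engine: bool = False,
                 engine_kwargs: Optional[Dict[str, Any]] = None,
                 template: Optional[Template] = None,
                 **kwargs) -> None:
        assert not use_async_engine  # async engine support is still marked TODO in this commit
        # No custom engine construction or infer() override anymore:
        # everything is delegated to VllmEngine.
        super().__init__(
            model_id_or_path=model_id_or_path,
            torch_dtype=torch_dtype,
            use_async_engine=use_async_engine,
            engine_kwargs=engine_kwargs,
            template=template,
            **kwargs)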
swift/llm/infer/infer_engine/vllm_engine.py
Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,7 @@
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
 from .infer_engine import InferEngine
 from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest, InferStreamer, patch_npu_vllm
+from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm_memory_leak
 
 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -70,6 +70,7 @@ def __init__(
     ) -> None:
         if engine_kwargs is None:
            engine_kwargs = {}
+        patch_vllm_memory_leak()
         self.use_async_engine = use_async_engine
         self.processor = get_model_tokenizer(
             model_id_or_path,

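Note: with patch_vllm_memory_leak() now called inside VllmEngine.__init__, every engine built on top of it (including the refactored GRPOVllmEngine, which delegates via super().__init__) gets the patch without applying it separately. The snippet below only illustrates the usual idempotent-guard pattern that makes such a call safe to repeat per process; it is not the actual ms-swift implementation.

# Illustrative pattern only, not ms-swift's real helper.
_memory_leak_patched = False


def patch_vllm_memory_leak_sketch() -> None:
    # Apply the monkey patch at most once, no matter how many engines are
    # constructed in the same process.
    global _memory_leak_patched
    if _memory_leak_patched:
        return
    _memory_leak_patched = True
    # ...monkey-patch the affected vllm internals here...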
swift/llm/infer/rollout.py
Lines changed: 0 additions & 1 deletion

@@ -86,7 +86,6 @@ def _register_rl_rollout_app(self):
         self.app.post('/infer/', response_model=None)(self.infer)
 
     def __init__(self, args: Union[List[str], DeployArguments, None] = None):
-        os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '1')
         super().__init__(args)
         safe_set_start_method()
         self.app = FastAPI(lifespan=self.lifespan)

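Note: rollout no longer forces VLLM_USE_V1='1' as a default, so vLLM's own default decides which engine version is used unless the variable is set explicitly. A hypothetical usage note, not part of the commit, for anyone who relied on the old behavior:

import os

# Pin the engine version yourself before vllm is imported (e.g. before launching the
# rollout app in swift/llm/infer/rollout.py); '0' would force the legacy engine instead.
os.environ.setdefault('VLLM_USE_V1', '1')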
swift/trainers/rlhf_trainer/grpo_trainer.py
Lines changed: 2 additions & 6 deletions

@@ -411,11 +411,6 @@ def split_llm(name):
     def prepare_vllm(self, model):
         from swift.tuners import Swift
         from swift.llm.infer.infer_engine import GRPOVllmEngine
-        if self.vllm_tensor_parallel_size > 1:
-            vllm_kwargs = {'distributed_executor_backend': 'external_launcher'}
-        else:
-            vllm_kwargs = {}
-
         max_num_seqs = (
             self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size
             * self.args.gradient_accumulation_steps)
@@ -436,7 +431,8 @@ def prepare_vllm(self, model):
             max_model_len=self.args.vllm_max_model_len,
             seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
             template=self.template,
-            **vllm_kwargs)
+            distributed_executor_backend='external_launcher',
+        )
         return engine
 
     @contextmanager

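Note: the TP>1 special case is gone and distributed_executor_backend='external_launcher' is now passed unconditionally, which lines up with the "fix v1 tp=1" item in the commit message. For reference, the max_num_seqs expression in the first hunk is a plain product; a quick worked example with assumed values:

# Assumed values for illustration only, mirroring the expression in the diff above.
per_device_train_batch_size = 4
vllm_tensor_parallel_size = 2
gradient_accumulation_steps = 8

max_num_seqs = (per_device_train_batch_size * vllm_tensor_parallel_size
                * gradient_accumulation_steps)
print(max_num_seqs)  # 64 concurrent sequences for this trainer's engine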