@@ -1,23 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from copy import copy, deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Dict, Optional

 import torch
-from packaging import version

-from swift.llm import InferRequest, Template, VllmEngine, get_model_tokenizer
-from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
-from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest, patch_vllm_memory_leak
+from swift.llm import Template, VllmEngine

 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
     os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
     os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600'
-    import vllm
-    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLM
 except Exception:
     raise

@@ -56,23 +48,15 @@ def __init__(
         engine_kwargs: Optional[Dict[str, Any]] = None,
         template: Optional[Template] = None,
     ) -> None:
-        os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '0')
-        if engine_kwargs is None:
-            engine_kwargs = {}
-        patch_vllm_memory_leak()
-        self.use_async_engine = use_async_engine
-        self.processor = get_model_tokenizer(
-            model_id_or_path,
-            torch_dtype,
-            load_model=False,
-            download_model=True,
+        assert not use_async_engine  # TODO
+        super().__init__(
+            model_id_or_path=model_id_or_path,
+            torch_dtype=torch_dtype,
+            use_async_engine=use_async_engine,
             model_type=model_type,
             use_hf=use_hf,
             hub_token=hub_token,
-            revision=revision)[1]
-        self._post_init(template)
-
-        self._prepare_engine_kwargs(
+            revision=revision,
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
             pipeline_parallel_size=pipeline_parallel_size,
@@ -81,78 +65,15 @@ def __init__(
             disable_custom_all_reduce=disable_custom_all_reduce,
             enforce_eager=enforce_eager,
             limit_mm_per_prompt=limit_mm_per_prompt,
+            device=device,
+            seed=seed,
             enable_lora=enable_lora,
             max_loras=max_loras,
             max_lora_rank=max_lora_rank,
             enable_prefix_caching=enable_prefix_caching,
-            device=device,
-            seed=seed,
-            distributed_executor_backend=distributed_executor_backend,
             enable_sleep_mode=enable_sleep_mode,
+            distributed_executor_backend=distributed_executor_backend,
             quantization=quantization,
-            **engine_kwargs,
+            engine_kwargs=engine_kwargs,
+            template=template,
         )
-        self._prepare_engine()
-        self._load_generation_config()
-
-    def _prepare_engine(self) -> None:
-        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config):
-            engine = LLM(**self.engine_args.__dict__)
-        self.engine = engine
-
-    @property
-    def inner_model(self):
-        return self.engine.llm_engine.model_executor.driver_worker.model_runner.model
-
-    @property
-    def inner_model_executor(self):
-        return self.engine.llm_engine.model_executor
-
-    def infer(
-        self,
-        infer_requests: List[InferRequest],
-        request_config: Optional[RequestConfig] = None,
-        metrics: Optional[List[Metric]] = None,
-        *,
-        template: Optional[Template] = None,
-        use_tqdm: Optional[bool] = None,
-        adapter_request: Optional[AdapterRequest] = None,
-    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
-        request_config = deepcopy(request_config or RequestConfig())
-        if template is None:
-            template = self.default_template
-        template.set_mode('vllm')
-        batched_inputs, error_list = self._batch_encode(
-            infer_requests, template=template, strict=getattr(self, 'strict', True))
-        self.set_default_max_tokens(request_config, batched_inputs)
-
-        prompts = []
-        for inputs in batched_inputs:
-            llm_inputs = {'prompt_token_ids': inputs['input_ids']}
-            mm_data = {}
-            for key in ['images', 'audios', 'videos']:
-                media_data = inputs.get(key) or []
-                if media_data:
-                    if version.parse(vllm.__version__) < version.parse('0.6'):
-                        assert len(media_data) == 1, (
-                            f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0')
-                        mm_data = {key.rstrip('s'): media_data[0]}
-                    else:
-                        mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data}
-            if mm_data:
-                llm_inputs['multi_modal_data'] = mm_data
-            prompts.append(llm_inputs)
-
-        generation_configs = []
-        seed = request_config.seed
-        assert seed >= 0, 'Seed is needed for GRPOVllmEngine.'
-        for i, _ in enumerate(prompts):
-            request_config = copy(request_config)
-            request_config.seed = seed + i
-            generation_config = self._prepare_generation_config(request_config)
-            self._add_stop_words(generation_config, request_config, template.template_meta)
-            generation_configs.append(generation_config)
-        outputs = self.engine.generate(prompts, generation_configs, use_tqdm=False)
-        return [
-            self._create_chat_completion_response(result, template, generation_configs[0], '') for result in outputs
-        ]
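
Below is a minimal, self-contained sketch (not part of this commit) of the per-prompt seeding pattern that the removed infer method implemented: each prompt gets its own sampling configuration seeded with `base_seed + i`, so rollouts for a batch are reproducible from the base seed yet differ across prompt copies. The model id and prompts are placeholders, and the sketch calls vllm's public LLM/SamplingParams API directly instead of the swift wrappers.

from vllm import LLM, SamplingParams

llm = LLM(model='Qwen/Qwen2.5-0.5B-Instruct')  # placeholder model id
prompts = ['Hello, who are you?', 'Write a haiku about GPUs.']
base_seed = 42
# One SamplingParams per prompt, offset by the prompt index (mirrors request_config.seed = seed + i above).
sampling_params = [SamplingParams(seed=base_seed + i, max_tokens=64) for i in range(len(prompts))]
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)  # a list of params matching the prompts is accepted
for output in outputs:
    print(output.outputs[0].text)

The per-index offset keeps repeated copies of the same prompt from collapsing onto identical completions while remaining reproducible given the base seed.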