
Commit 4fc7e93

DarkLight1337 authored and epwalsh committed
[V1] Get supported tasks from model runner instead of model config (vllm-project#21585)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 7ffd56e commit 4fc7e93

File tree

19 files changed (+200, -54 lines)


vllm/entrypoints/llm.py

Lines changed: 17 additions & 7 deletions
@@ -14,6 +14,7 @@
 from tqdm.auto import tqdm
 from typing_extensions import TypeVar, deprecated
 
+import vllm.envs as envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence,
                               create_sort_beams_key_function)
@@ -44,9 +45,10 @@
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
+from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
@@ -277,6 +279,16 @@ def __init__(
         self.request_counter = Counter()
         self.default_sampling_params: Union[dict[str, Any], None] = None
 
+        if envs.VLLM_USE_V1:
+            supported_tasks = self.llm_engine \
+                .get_supported_tasks()  # type: ignore
+        else:
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+
+        logger.info("Supported_tasks: %s", supported_tasks)
+
+        self.supported_tasks = supported_tasks
+
     def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
@@ -1170,8 +1182,7 @@ def embed(
             A list of `EmbeddingRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        model_config = self.llm_engine.model_config
-        if "embed" not in model_config.supported_tasks:
+        if "embed" not in self.supported_tasks:
             raise ValueError("Embedding API is not supported by this model. "
                              "Please set `--task embed`.")
 
@@ -1215,8 +1226,7 @@ def classify(
             A list of `ClassificationRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        model_config = self.llm_engine.model_config
-        if "classify" not in model_config.supported_tasks:
+        if "classify" not in self.supported_tasks:
             raise ValueError(
                 "Classification API is not supported by this model. "
                 "Please set `--task classify`.")
@@ -1397,8 +1407,8 @@ def score(
 
             raise ValueError(" ".join(messages))
 
-        if all(t not in model_config.supported_tasks
-               for t in ("embed", "classify")):
+        supported_tasks = self.supported_tasks
+        if all(t not in supported_tasks for t in ("embed", "classify")):
             raise ValueError("Score API is not supported by this model. "
                              "Please set `--task embed` or `--task classify`.")
 
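For context, the hunks above replace per-call lookups of model_config.supported_tasks with a task tuple resolved once in __init__ and stored on self.supported_tasks. A minimal standalone sketch of the resulting gating check follows; the helper name and exact messages are illustrative, not part of the commit:

from typing import Sequence


def ensure_task_supported(task: str, supported_tasks: Sequence[str]) -> None:
    # Mirror the ValueError the entrypoint raises when a task is unavailable.
    if task not in supported_tasks:
        raise ValueError(f"The {task!r} API is not supported by this model. "
                         f"Please set `--task {task}`.")


supported = ("embed", "encode")                # e.g. what an embedding model reports
ensure_task_supported("embed", supported)      # passes
# ensure_task_supported("classify", supported) # would raise ValueError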

vllm/entrypoints/openai/api_server.py

Lines changed: 19 additions & 13 deletions
@@ -1586,6 +1586,14 @@ async def init_app_state(
     state.vllm_config = vllm_config
     model_config = vllm_config.model_config
 
+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
     resolved_chat_template = load_chat_template(args.chat_template)
     if resolved_chat_template is not None:
         # Get the tokenizer to check official template
@@ -1647,7 +1655,7 @@ async def init_app_state(
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -1664,7 +1672,7 @@ async def init_app_state(
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -1673,40 +1681,38 @@ async def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_pooling = OpenAIServingPooling(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "encode" in model_config.supported_tasks else None
+    ) if "encode" in supported_tasks else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None
     state.openai_serving_classification = ServingClassification(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "classify" in model_config.supported_tasks else None
+    ) if "classify" in supported_tasks else None
 
-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None
 
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
@@ -1721,13 +1727,13 @@ async def init_app_state(
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
     state.openai_serving_translation = OpenAIServingTranslation(
        engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
     state.task = model_config.task
 
     state.enable_server_load_tracking = args.enable_server_load_tracking
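The same V1/V0 discovery branch is added at the top of init_app_state (and again in run_batch below). A hedged sketch of that branch in isolation; the wrapper function is hypothetical, and engine_client/model_config stand in for the real objects passed to init_app_state:

import vllm.envs as envs


async def resolve_supported_tasks(engine_client, model_config):
    # Sketch of the branch above, not a vLLM API of its own.
    if envs.VLLM_USE_V1:
        # V1: ask the engine, which queries the model runner for its tasks.
        return await engine_client.get_supported_tasks()
    # V0: fall back to the static list derived from the model config.
    return model_config.supported_tasks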

vllm/entrypoints/openai/run_batch.py

Lines changed: 14 additions & 7 deletions
@@ -14,6 +14,7 @@
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
@@ -335,6 +336,14 @@ async def run_batch(
 
     model_config = vllm_config.model_config
 
+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
@@ -351,27 +360,25 @@ async def run_batch(
         chat_template=None,
         chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None
 
-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
 
     openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)

vllm/executor/executor_base.py

Lines changed: 4 additions & 4 deletions
@@ -16,8 +16,8 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
+from vllm.tasks import SupportedTask
 from vllm.utils import make_async
 from vllm.worker.worker_base import WorkerBase
 
@@ -136,9 +136,9 @@ def rpc_func(worker: WorkerBase) -> _R:
         return self.collective_rpc(rpc_func)
 
     @cached_property  # Avoid unnecessary RPC calls
-    def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
-        output = self.collective_rpc("get_supported_pooling_tasks")
-        return tuple({task for tasks in output for task in tasks})
+    def supported_tasks(self) -> tuple[SupportedTask, ...]:
+        output = self.collective_rpc("get_supported_tasks")
+        return output[0]
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
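The executor now exposes the full task tuple through a single cached collective RPC, and keeps only the first worker's reply since every worker holds the same model. A toy sketch of that pattern; the classes below are illustrative stand-ins, not vLLM's executor or worker types:

from functools import cached_property


class ToyWorker:
    def get_supported_tasks(self) -> tuple[str, ...]:
        # Each worker reports the tasks its model runner can execute.
        return ("generate", "transcription")


class ToyExecutor:
    def __init__(self, workers):
        self.workers = workers

    def collective_rpc(self, method: str) -> list:
        # Invoke `method` on every worker and gather the replies.
        return [getattr(worker, method)() for worker in self.workers]

    @cached_property  # avoid repeating the RPC on every access
    def supported_tasks(self) -> tuple[str, ...]:
        output = self.collective_rpc("get_supported_tasks")
        # All workers run the same model, so the first reply suffices.
        return output[0]


executor = ToyExecutor([ToyWorker(), ToyWorker()])
print(executor.supported_tasks)  # ('generate', 'transcription')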

vllm/model_executor/layers/pooler.py

Lines changed: 2 additions & 1 deletion
@@ -16,8 +16,9 @@
 from vllm.model_executor.pooling_metadata import (  # noqa: E501
     PoolingMetadata as V0PoolingMetadata)
 from vllm.model_executor.pooling_metadata import PoolingTensors
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
+from vllm.tasks import PoolingTask
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

vllm/model_executor/models/bert.py

Lines changed: 1 addition & 1 deletion
@@ -26,8 +26,8 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

vllm/model_executor/models/gritlm.py

Lines changed: 1 addition & 1 deletion
@@ -16,8 +16,8 @@
     get_prompt_token_ids)
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import PoolerOutput
+from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsV0Only

vllm/model_executor/models/modernbert.py

Lines changed: 1 addition & 1 deletion
@@ -23,8 +23,8 @@
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask
 
 from .interfaces import SupportsCrossEncoding, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix

vllm/pooling_params.py

Lines changed: 2 additions & 3 deletions
@@ -1,17 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Optional
 
 import msgspec
 
 from vllm.sampling_params import RequestOutputKind
+from vllm.tasks import PoolingTask
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
 
-PoolingTask = Literal["encode", "embed", "classify", "score"]
-
 
 class PoolingParams(
     msgspec.Struct,

vllm/tasks.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Literal, get_args
+
+GenerationTask = Literal["generate", "transcription"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed", "classify", "score"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask, PoolingTask]
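The new vllm/tasks.py centralizes the task literals that the other files now import. A short usage sketch, assuming the module is importable exactly as added above:

from typing import get_args

from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask

print(GENERATION_TASKS)  # ('generate', 'transcription')
print(POOLING_TASKS)     # ('encode', 'embed', 'classify', 'score')
# Nested Literal arguments are flattened, so this yields all six task names.
print(get_args(SupportedTask))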
