24 changes: 17 additions & 7 deletions vllm/entrypoints/llm.py
@@ -14,6 +14,7 @@
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated

import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function)
@@ -44,9 +45,10 @@
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
@@ -277,6 +279,16 @@ def __init__(
self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None

if envs.VLLM_USE_V1:
supported_tasks = self.llm_engine \
.get_supported_tasks() # type: ignore
else:
supported_tasks = self.llm_engine.model_config.supported_tasks

logger.info("Supported_tasks: %s", supported_tasks)

self.supported_tasks = supported_tasks

def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
@@ -1170,8 +1182,7 @@ def embed(
A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
"""
model_config = self.llm_engine.model_config
if "embed" not in model_config.supported_tasks:
if "embed" not in self.supported_tasks:
raise ValueError("Embedding API is not supported by this model. "
"Please set `--task embed`.")

@@ -1215,8 +1226,7 @@ def classify(
A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
"""
model_config = self.llm_engine.model_config
if "classify" not in model_config.supported_tasks:
if "classify" not in self.supported_tasks:
raise ValueError(
"Classification API is not supported by this model. "
"Please set `--task classify`.")
@@ -1397,8 +1407,8 @@ def score(

raise ValueError(" ".join(messages))

if all(t not in model_config.supported_tasks
for t in ("embed", "classify")):
supported_tasks = self.supported_tasks
if all(t not in supported_tasks for t in ("embed", "classify")):
raise ValueError("Score API is not supported by this model. "
"Please set `--task embed` or `--task classify`.")

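For context, a minimal runnable sketch (not part of the diff) of how the new `self.supported_tasks` attribute gates the pooling entrypoints; `FakeLLM` is a hypothetical stand-in for an `LLM` instance:

```python
from typing import Literal

SupportedTask = Literal["generate", "transcription", "encode", "embed",
                        "classify", "score"]


class FakeLLM:
    """Hypothetical stand-in mirroring the task check in LLM.embed()."""

    def __init__(self, supported_tasks: tuple[SupportedTask, ...]) -> None:
        # Populated once at construction time, as in the diff above.
        self.supported_tasks = supported_tasks

    def embed(self, prompts: list[str]) -> list[str]:
        if "embed" not in self.supported_tasks:
            raise ValueError("Embedding API is not supported by this model. "
                             "Please set `--task embed`.")
        return [f"embedding({p})" for p in prompts]


llm = FakeLLM(supported_tasks=("encode", "embed"))
print(llm.embed(["hello world"]))  # ['embedding(hello world)']
```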
32 changes: 19 additions & 13 deletions vllm/entrypoints/openai/api_server.py
@@ -1586,6 +1586,14 @@ async def init_app_state(
state.vllm_config = vllm_config
model_config = vllm_config.model_config

if envs.VLLM_USE_V1:
supported_tasks = await engine_client \
.get_supported_tasks() # type: ignore
else:
supported_tasks = model_config.supported_tasks

logger.info("Supported_tasks: %s", supported_tasks)

resolved_chat_template = load_chat_template(args.chat_template)
if resolved_chat_template is not None:
# Get the tokenizer to check official template
@@ -1647,7 +1655,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@@ -1664,7 +1672,7 @@
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
@@ -1673,40 +1681,38 @@
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if "encode" in model_config.supported_tasks else None
) if "encode" in supported_tasks else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if "embed" in model_config.supported_tasks else None
) if "embed" in supported_tasks else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if "classify" in model_config.supported_tasks else None
) if "classify" in supported_tasks else None

enable_serving_reranking = ("classify" in model_config.supported_tasks
and getattr(model_config.hf_config,
"num_labels", 0) == 1)
enable_serving_reranking = ("classify" in supported_tasks and getattr(
model_config.hf_config, "num_labels", 0) == 1)
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
) if ("embed" in supported_tasks or enable_serving_reranking) else None

state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
@@ -1721,13 +1727,13 @@
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if "transcription" in model_config.supported_tasks else None
) if "transcription" in supported_tasks else None
state.openai_serving_translation = OpenAIServingTranslation(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if "transcription" in model_config.supported_tasks else None
) if "transcription" in supported_tasks else None
state.task = model_config.task

state.enable_server_load_tracking = args.enable_server_load_tracking
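As an aside, a hedged sketch of the wiring pattern used in `init_app_state`: each serving handler is constructed only when its task appears in `supported_tasks`, otherwise the slot stays `None`. The handler classes below are illustrative placeholders, not the real `OpenAIServing*` classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ChatHandler:       # placeholder for OpenAIServingChat
    model_name: str


@dataclass
class EmbeddingHandler:  # placeholder for OpenAIServingEmbedding
    model_name: str


def build_handlers(model_name: str, supported_tasks: tuple[str, ...]):
    chat: Optional[ChatHandler] = (
        ChatHandler(model_name) if "generate" in supported_tasks else None)
    embedding: Optional[EmbeddingHandler] = (
        EmbeddingHandler(model_name) if "embed" in supported_tasks else None)
    return chat, embedding


chat, embedding = build_handlers("intfloat/e5-small", ("encode", "embed"))
print(chat)       # None: the model does not support "generate"
print(embedding)  # EmbeddingHandler(model_name='intfloat/e5-small')
```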
21 changes: 14 additions & 7 deletions vllm/entrypoints/openai/run_batch.py
Expand Up @@ -14,6 +14,7 @@
from prometheus_client import start_http_server
from tqdm import tqdm

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
@@ -335,6 +336,14 @@ async def run_batch(

model_config = vllm_config.model_config

if envs.VLLM_USE_V1:
supported_tasks = await engine_client \
.get_supported_tasks() # type: ignore
else:
supported_tasks = model_config.supported_tasks

logger.info("Supported_tasks: %s", supported_tasks)

# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
@@ -351,27 +360,25 @@
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if "embed" in model_config.supported_tasks else None
) if "embed" in supported_tasks else None

enable_serving_reranking = ("classify" in model_config.supported_tasks
and getattr(model_config.hf_config,
"num_labels", 0) == 1)
enable_serving_reranking = ("classify" in supported_tasks and getattr(
model_config.hf_config, "num_labels", 0) == 1)

openai_serving_scores = ServingScores(
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
) if ("embed" in supported_tasks or enable_serving_reranking) else None

tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
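For illustration only, the reranking condition above in isolation: score serving doubles as reranking when the model supports `classify` and its HF config exposes exactly one label. `SimpleNamespace` stands in for the real HF config object:

```python
from types import SimpleNamespace


def reranking_enabled(supported_tasks: tuple[str, ...], hf_config) -> bool:
    # Mirrors: "classify" in supported_tasks and num_labels == 1
    return ("classify" in supported_tasks
            and getattr(hf_config, "num_labels", 0) == 1)


cross_encoder_cfg = SimpleNamespace(num_labels=1)  # e.g. a cross-encoder reranker
classifier_cfg = SimpleNamespace(num_labels=3)     # e.g. a 3-way classifier

print(reranking_enabled(("classify",), cross_encoder_cfg))  # True
print(reranking_enabled(("classify",), classifier_cfg))     # False
print(reranking_enabled(("embed",), cross_encoder_cfg))     # False
```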
8 changes: 4 additions & 4 deletions vllm/executor/executor_base.py
@@ -16,8 +16,8 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.pooling_params import PoolingTask
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase

@@ -136,9 +136,9 @@ def rpc_func(worker: WorkerBase) -> _R:
return self.collective_rpc(rpc_func)

@cached_property # Avoid unnecessary RPC calls
def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
output = self.collective_rpc("get_supported_pooling_tasks")
return tuple({task for tasks in output for task in tasks})
def supported_tasks(self) -> tuple[SupportedTask, ...]:
output = self.collective_rpc("get_supported_tasks")
return output[0]

def execute_model(
self, execute_model_req: ExecuteModelRequest
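A brief sketch of the return-value handling in the new `supported_tasks` property, under the assumption (implied by taking `output[0]`) that every worker reports the same task tuple; `fake_collective_rpc` is an illustrative stand-in for `collective_rpc`:

```python
def fake_collective_rpc(method: str) -> list[tuple[str, ...]]:
    # Illustrative: two workers both answer "get_supported_tasks" identically.
    per_worker_answer = ("generate", "transcription")
    return [per_worker_answer, per_worker_answer]


output = fake_collective_rpc("get_supported_tasks")
supported_tasks = output[0]  # identical across workers, so the first suffices
print(supported_tasks)       # ('generate', 'transcription')
```

Unlike the old `supported_pooling_tasks`, which unioned per-worker pooling tasks into a set, the new property returns the first worker's full task tuple unchanged.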
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/pooler.py
@@ -16,8 +16,9 @@
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/bert.py
@@ -26,8 +26,8 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask

from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
2 changes: 1 addition & 1 deletion vllm/model_executor/models/gritlm.py
@@ -16,8 +16,8 @@
get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .interfaces import SupportsV0Only
2 changes: 1 addition & 1 deletion vllm/model_executor/models/modernbert.py
@@ -23,8 +23,8 @@
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask

from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
5 changes: 2 additions & 3 deletions vllm/pooling_params.py
@@ -1,17 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

import msgspec

from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask

if TYPE_CHECKING:
from vllm.config import ModelConfig

PoolingTask = Literal["encode", "embed", "classify", "score"]


class PoolingParams(
msgspec.Struct,
11 changes: 11 additions & 0 deletions vllm/tasks.py
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, get_args

GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)

PoolingTask = Literal["encode", "embed", "classify", "score"]
POOLING_TASKS = get_args(PoolingTask)

SupportedTask = Literal[GenerationTask, PoolingTask]
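A quick usage sketch of the new module (the constants are copied from the diff; the membership checks are assumed, typical call sites): `get_args` turns the `Literal` aliases into runtime tuples, and nested `Literal`s flatten, so `SupportedTask` covers all six task names:

```python
from typing import Literal, get_args

GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)

PoolingTask = Literal["encode", "embed", "classify", "score"]
POOLING_TASKS = get_args(PoolingTask)

SupportedTask = Literal[GenerationTask, PoolingTask]

print(POOLING_TASKS)             # ('encode', 'embed', 'classify', 'score')
print("embed" in POOLING_TASKS)  # True
print(get_args(SupportedTask))   # all six task names, flattened
```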
4 changes: 4 additions & 0 deletions vllm/v1/engine/async_llm.py
@@ -21,6 +21,7 @@
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -211,6 +212,9 @@ def shutdown(self):
if handler := getattr(self, "output_handler", None):
handler.cancel()

async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async()

async def add_request(
self,
request_id: str,
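A minimal async sketch of the delegation added above; `FakeEngineCore` and `FakeAsyncLLM` are hypothetical stand-ins for the V1 engine-core client and `AsyncLLM`:

```python
import asyncio
from typing import Literal

SupportedTask = Literal["generate", "transcription", "encode", "embed",
                        "classify", "score"]


class FakeEngineCore:
    """Stand-in for the V1 engine-core client."""

    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
        return ("generate", "transcription")


class FakeAsyncLLM:
    """Stand-in mirroring AsyncLLM.get_supported_tasks."""

    def __init__(self) -> None:
        self.engine_core = FakeEngineCore()

    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return await self.engine_core.get_supported_tasks_async()


print(asyncio.run(FakeAsyncLLM().get_supported_tasks()))  # ('generate', 'transcription')
```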
11 changes: 9 additions & 2 deletions vllm/v1/engine/core.py
@@ -23,6 +23,7 @@
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (bind_process_name, make_zmq_socket,
@@ -195,11 +196,17 @@ def _initialize_kv_caches(
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_executor.supported_tasks

def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
if pooling_params := request.pooling_params:
supported_pooling_tasks = (
self.model_executor.supported_pooling_tasks)
supported_pooling_tasks = [
task for task in self.get_supported_tasks()
if task in POOLING_TASKS
]

if pooling_params.task not in supported_pooling_tasks:
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
f"Supported tasks: {supported_pooling_tasks}")
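Finally, a minimal sketch of the validation added to `add_request`: supported pooling tasks are the intersection of the executor's full task tuple with `POOLING_TASKS`, and an unsupported `pooling_params.task` is rejected. The helper below is illustrative, not the engine's actual code path:

```python
from typing import Literal, get_args

PoolingTask = Literal["encode", "embed", "classify", "score"]
POOLING_TASKS = get_args(PoolingTask)


def validate_pooling_task(task: str,
                          all_supported_tasks: tuple[str, ...]) -> None:
    supported_pooling_tasks = [
        t for t in all_supported_tasks if t in POOLING_TASKS
    ]
    if task not in supported_pooling_tasks:
        raise ValueError(f"Unsupported task: {task!r} "
                         f"Supported tasks: {supported_pooling_tasks}")


validate_pooling_task("embed", ("generate", "encode", "embed"))  # passes
try:
    validate_pooling_task("classify", ("generate", "encode", "embed"))
except ValueError as e:
    print(e)  # Unsupported task: 'classify' Supported tasks: ['encode', 'embed']
```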