diff --git a/vllm_spyre/core/scheduler.py b/vllm_spyre/core/scheduler.py
index 38adfa49a..2dc740c98 100644
--- a/vllm_spyre/core/scheduler.py
+++ b/vllm_spyre/core/scheduler.py
@@ -1267,7 +1267,6 @@ def schedule(
                     multi_modal_placeholders=(
                         seq_group.multi_modal_placeholders
                         if scheduler_outputs.num_prefill_groups > 0 else None),
-                    prompt_adapter_request=seq_group.prompt_adapter_request,
                 )
             else:
                 # When SPMD mode is enabled, we only send delta data except for
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 6ef9025c2..71139b470 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -3,7 +3,7 @@
 from collections import deque
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Literal, Optional, cast, get_args
 
 import torch
 from torch import nn
@@ -34,6 +34,19 @@
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 
+#############################################################
+# from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
+# TODO: remove when we have this in vllm/tasks.py
+#############################################################
+GenerationTask = Literal["generate", "transcription"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed", "classify", "score"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask]
+#############################################################
+
 logger = init_logger(__name__)
@@ -76,7 +89,6 @@ def __init__(
         self.scheduler_config = vllm_config.scheduler_config
         self.device_config = vllm_config.device_config
         self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
         self.pad_token_id = 0
@@ -375,6 +387,14 @@ def prepare_model_input(
         else:
             return self._prepare_decode(scheduler_output.scheduled_cached_reqs)
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if "generate" in self.model_config.supported_tasks:
+            tasks.extend(["generate"])
+
+        return tuple(tasks)
+
     @SpyrePlatform.inference_mode()
     def execute_model(
         self,
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index fbe69374e..3ac799d14 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -29,7 +29,8 @@
 from vllm_spyre.model_executor.model_loader import spyre_setup
 from vllm_spyre.platform import SpyrePlatform
 from vllm_spyre.v1.worker.spyre_model_runner import (
-    ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner)
+    ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner,
+    SupportedTask)
 
 logger = init_logger(__name__)
@@ -616,6 +617,9 @@ def do_metadata_broadcast(self) -> bool:
     def kv_cache(self) -> Optional[list[list[torch.Tensor]]]:
         return None
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()
+
     @SpyrePlatform.inference_mode()
     def execute_model(
         self,
diff --git a/vllm_spyre/worker/spyre_embedding_model_runner.py b/vllm_spyre/worker/spyre_embedding_model_runner.py
index fe5c830b3..3be4e67c3 100644
--- a/vllm_spyre/worker/spyre_embedding_model_runner.py
+++ b/vllm_spyre/worker/spyre_embedding_model_runner.py
@@ -42,11 +42,20 @@ def __init__(
                          is_driver_worker=is_driver_worker)
 
         pooler_config = model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.CLS,
-            normalize=True,
-            softmax=False)
+        if hasattr(Pooler, "from_config_with_defaults"):
+            # TODO: remove this when we no longer support
+            # vllm version v0.9.2
+            self.pooler = Pooler.from_config_with_defaults(
+                pooler_config,
+                pooling_type=PoolingType.CLS,
+                normalize=True,
+                softmax=False)
+        else:
+            self.pooler = Pooler.for_embed(
+                pooler_config=pooler_config,
+                default_pooling_type=PoolingType.CLS,
+                default_normalize=True,
+                default_softmax=False)
 
     def load_model(self, prompt_lens: Iterable[int],
                    num_decode_tokens: Iterable[int]) -> None:
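
Note (illustration only, not part of this patch): the intent of the new get_supported_tasks() plumbing is that the worker simply forwards whatever its model runner reports, and the runner reports "generate" only when the model config lists it as a supported task. A minimal, self-contained sketch of that behaviour, using hypothetical stub classes rather than the real runner and worker:

    # Hypothetical stand-ins; the real classes are the model runner and
    # worker modified in the diff above.
    class StubModelConfig:
        supported_tasks = ("generate",)

    class StubModelRunner:
        model_config = StubModelConfig()

        def get_supported_tasks(self) -> tuple[str, ...]:
            tasks: list[str] = []
            if "generate" in self.model_config.supported_tasks:
                tasks.append("generate")
            return tuple(tasks)

    class StubWorker:
        model_runner = StubModelRunner()

        def get_supported_tasks(self) -> tuple[str, ...]:
            # The worker delegates straight to its runner.
            return self.model_runner.get_supported_tasks()

    assert StubWorker().get_supported_tasks() == ("generate",)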