1 change: 0 additions & 1 deletion vllm_spyre/core/scheduler.py
@@ -1267,7 +1267,6 @@ def schedule(
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
-prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# When SPMD mode is enabled, we only send delta data except for
24 changes: 22 additions & 2 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -3,7 +3,7 @@
from collections import deque
from collections.abc import Iterable
from dataclasses import asdict, dataclass
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Literal, Optional, cast, get_args

import torch
from torch import nn
@@ -34,6 +34,19 @@

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput

+#############################################################
+# from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
+# TODO: remove when we have this in vllm/tasks.py
+#############################################################
+GenerationTask = Literal["generate", "transcription"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed", "classify", "score"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask]
+#############################################################

logger = init_logger(__name__)


@@ -76,7 +89,6 @@ def __init__(
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
-self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config

self.pad_token_id = 0
@@ -375,6 +387,14 @@ def prepare_model_input(
else:
return self._prepare_decode(scheduler_output.scheduled_cached_reqs)

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if "generate" in self.model_config.supported_tasks:
+            tasks.extend(["generate"])
Comment on lines +393 to +394
Collaborator Author:
Ideally we would want this to come from the model directly:

if is_text_generation_model(model):
    supported_tasks.append("generate")

but SpyreCausalLM doesn't seem to support it atm:

>>> type(self.model)
<class 'vllm_spyre.model_executor.model_loader.spyre.SpyreCausalLM'>
>>> is_text_generation_model(self.model)
False
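A minimal sketch of what that could look like once SpyreCausalLM advertises the text-generation interface (illustrative only, not part of this PR; it assumes is_text_generation_model remains importable from vllm.model_executor.models):

from vllm.model_executor.models import is_text_generation_model

def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    tasks = list[SupportedTask]()
    # Ask the model itself rather than checking model_config.supported_tasks.
    if is_text_generation_model(self.model):
        tasks.append("generate")
    return tuple(tasks)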


+        return tuple(tasks)

@SpyrePlatform.inference_mode()
def execute_model(
self,
6 changes: 5 additions & 1 deletion vllm_spyre/v1/worker/spyre_worker.py
@@ -29,7 +29,8 @@
from vllm_spyre.model_executor.model_loader import spyre_setup
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre.v1.worker.spyre_model_runner import (
-    ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner)
+    ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner,
+    SupportedTask)

logger = init_logger(__name__)

@@ -616,6 +617,9 @@ def do_metadata_broadcast(self) -> bool:
def kv_cache(self) -> Optional[list[list[torch.Tensor]]]:
return None

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()

@SpyrePlatform.inference_mode()
def execute_model(
self,
19 changes: 14 additions & 5 deletions vllm_spyre/worker/spyre_embedding_model_runner.py
@@ -42,11 +42,20 @@ def __init__(
is_driver_worker=is_driver_worker)

pooler_config = model_config.pooler_config
-        self.pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.CLS,
-            normalize=True,
-            softmax=False)
+        if hasattr(Pooler, "from_config_with_defaults"):
+            # TODO: remove this when we no longer support
+            # vllm version v0.9.2
+            self.pooler = Pooler.from_config_with_defaults(
+                pooler_config,
+                pooling_type=PoolingType.CLS,
+                normalize=True,
+                softmax=False)
+        else:
+            self.pooler = Pooler.for_embed(
+                pooler_config=pooler_config,
+                default_pooling_type=PoolingType.CLS,
+                default_normalize=True,
+                default_softmax=False)
Comment on lines +45 to +58
@prashantgupta24 (Collaborator Author), Jul 25, 2025:
@maxdebayser I think we just need this piece of code from your PR to handle the breaking changes on vllm:main; I've removed all other pooling changes from this PR. I'm fine with waiting for your PR to get in first, in which case I'll rebase and remove this change.
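For reference, a quick way to check which Pooler factory the installed vllm exposes (illustrative snippet, not part of this PR):

from vllm.model_executor.layers.pooler import Pooler

# from_config_with_defaults should exist on vllm <= v0.9.2,
# for_embed on newer vllm main builds.
print(hasattr(Pooler, "from_config_with_defaults"))
print(hasattr(Pooler, "for_embed"))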


def load_model(self, prompt_lens: Iterable[int],
num_decode_tokens: Iterable[int]) -> None: