Commit 7c75f50

Fix API mismatch after PR 21585 (#43)
Mirroring changes from vllm-project/vllm#21585 to the HPU code.

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
1 parent: 0cc8bb6

File tree

2 files changed: +41 -0 lines changed


vllm_gaudi/v1/worker/hpu_model_runner.py

37 additions, 0 deletions
@@ -46,6 +46,11 @@
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
 from vllm.distributed.parallel_state import get_pp_group
 
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
+
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput
 
@@ -2349,3 +2354,35 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self._PAD_SLOT_ID = num_blocks * self.block_size
 
         htorch.hpu.synchronize()
+
+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        model = self.get_model()
+        if not is_pooling_model(model):
+            return []
+
+        return list(model.pooler.get_supported_tasks())
+
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
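For context: get_supported_tasks() lets the engine ask the runner which tasks the loaded model can actually serve, rather than assuming them. Below is a minimal, self-contained sketch of the generation-task resolution added above; FakeModel and resolve_generation_tasks are illustrative stand-ins, not part of this commit, and the real checks live in vllm.model_executor.models.interfaces / interfaces_base.

from dataclasses import dataclass


# Hypothetical stand-in for a loaded model. In vLLM the equivalent checks
# (is_text_generation_model, supports_transcription) inspect the model
# class; here they are plain boolean attributes so the sketch runs alone.
@dataclass
class FakeModel:
    is_text_gen: bool = False
    transcribes: bool = False
    supports_transcription_only: bool = False


def resolve_generation_tasks(model: FakeModel) -> list[str]:
    supported: list[str] = []
    if model.is_text_gen:
        supported.append("generate")
    if model.transcribes:
        # Transcription-only models short-circuit and advertise nothing
        # but "transcription", mirroring the early return in the diff.
        if model.supports_transcription_only:
            return ["transcription"]
        supported.append("transcription")
    return supported


assert resolve_generation_tasks(FakeModel(is_text_gen=True)) == ["generate"]
assert resolve_generation_tasks(
    FakeModel(is_text_gen=True, transcribes=True)) == ["generate", "transcription"]
assert resolve_generation_tasks(
    FakeModel(is_text_gen=True, transcribes=True,
              supports_transcription_only=True)) == ["transcription"]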

vllm_gaudi/v1/worker/hpu_worker.py

4 additions, 0 deletions
@@ -9,6 +9,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+from vllm.tasks import SupportedTask
 from vllm_gaudi.extension.profiler import HabanaMemoryProfiler, format_bytes
 
 import vllm.envs as envs
@@ -230,6 +231,9 @@ def execute_model(
         # TODO(woosuk): Send the output to the engine process.
         return output if self.rank == 0 else None
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()
+
 
 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
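The worker-side change is pure delegation: the worker forwards the query to its model runner, which owns the loaded model and can inspect it. A hypothetical caller-side sketch follows; validate_requested_task is illustrative only and not part of this commit or of vLLM's API.

# Illustrative only: how an engine-side caller might guard a request
# using the new worker API. Only get_supported_tasks() is from this commit.
def validate_requested_task(worker, requested: str) -> None:
    supported = worker.get_supported_tasks()
    if requested not in supported:
        raise ValueError(
            f"task {requested!r} not supported by this model; "
            f"supported tasks: {supported}")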
