diff --git a/vllm/config.py b/vllm/config.py
index 108badf150c86..356e0aabdf233 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -372,9 +372,9 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if device_config.device_type not in ("cuda", "tpu"):
+        if device_config.device_type not in ("cuda", "tpu", "xpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA or TPU. "
+                "Async output processing is only supported for CUDA, TPU, XPU. "
                 "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index d3c763c995b34..a54422233f035 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -2,8 +2,8 @@
 import time
 import weakref
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    Type, TypeVar)
 
 import torch
 import torch.nn as nn
@@ -56,6 +56,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
     virtual_engine: Optional[int] = None
     seq_lens: Optional[List[int]] = None
     query_lens: Optional[List[int]] = None
+    async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -570,6 +571,9 @@ def execute_model(
         if not self.is_driver_worker:
             return []
 
+        if model_input.async_callback is not None:
+            model_input.async_callback()
+
         # Sample the next token.
         output: SamplerOutput = self.model.sample(
             logits=logits,
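
For reviewers unfamiliar with the async output processing path: the diff adds an optional `async_callback` field to `ModelInputForXPU` and fires it in `execute_model` just before sampling, so the engine can process the previous step's output while the current step is still in flight. Below is a minimal, self-contained sketch of that pattern. It is illustrative only and not part of the diff; `ToyModelInput`, `ToyRunner`, and `process_prev_step_output` are hypothetical names, and the "forward pass" and "sampling" are stand-in arithmetic.

```python
# Sketch of the async-callback hook this PR wires into the XPU model runner.
# Assumption: the engine attaches the callback to the model input before
# execution, as the ModelInputForXPU.async_callback field suggests.
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyModelInput:
    token_ids: List[int]
    # Mirrors the field added to ModelInputForXPU in this diff.
    async_callback: Optional[Callable] = None


class ToyRunner:
    def execute_model(self, model_input: ToyModelInput) -> List[int]:
        # Stand-in for the model forward pass producing logits.
        logits = [t * 2 for t in model_input.token_ids]

        # As in the diff: invoke the callback before sampling, giving the
        # engine a window to finish processing the *previous* step's output.
        if model_input.async_callback is not None:
            model_input.async_callback()

        # Stand-in for sampling the next tokens from the logits.
        return [max(0, logit - 1) for logit in logits]


def process_prev_step_output() -> None:
    print("processing previous step's output while sampling proceeds")


runner = ToyRunner()
out = runner.execute_model(
    ToyModelInput(token_ids=[1, 2, 3],
                  async_callback=process_prev_step_output))
print(out)
```

Placing the callback between the forward pass and sampling matches what the CUDA and TPU runners already do, which is also why the `verify_async_output_proc` allowlist in `vllm/config.py` gains `"xpu"`.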