Skip to content

Commit

Permalink
add async output processing for xpu
Browse files Browse the repository at this point in the history
  • Loading branch information
jikunshang committed Sep 27, 2024
1 parent 7193774 commit 147e9c5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
self.use_async_output_proc = False
return

if device_config.device_type not in ("cuda", "tpu"):
if device_config.device_type not in ("cuda", "tpu", "xpu"):
logger.warning(
"Async output processing is only supported for CUDA or TPU. "
"Async output processing is only supported for CUDA, TPU, XPU. "
"Disabling it for other platforms.")
self.use_async_output_proc = False
return
Expand Down
8 changes: 6 additions & 2 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import time
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type, TypeVar)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -56,6 +56,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
async_callback: Optional[Callable] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
Expand Down Expand Up @@ -570,6 +571,9 @@ def execute_model(
if not self.is_driver_worker:
return []

if model_input.async_callback is not None:
model_input.async_callback()

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
Expand Down

0 comments on commit 147e9c5

Please sign in to comment.