2 changes: 2 additions & 0 deletions pytest.ini
@@ -39,6 +39,8 @@ norecursedirs =
     vllm-empty/tests/neuron
     ; fastsafetensors not support npu now
     vllm-empty/tests/fastsafetensors_loader
+    ; Enable after https://github.com/vllm-project/vllm-ascend/issues/808 resolved
+    vllm-empty/tests/benchmarks
 
 addopts = --ignore=vllm-empty/tests/test_utils.py
     --ignore=vllm-empty/tests/test_config.py
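The pytest.ini hunk adds vllm-empty/tests/benchmarks to norecursedirs, so pytest stops collecting those benchmark tests until the linked issue is resolved. For illustration only, the same exclusion could also be expressed programmatically with pytest's collect_ignore_glob hook; the conftest.py fragment below is a hypothetical sketch and not part of this change:

# conftest.py -- hypothetical alternative to the norecursedirs entry above,
# shown only for illustration; it is not part of this PR.
# pytest evaluates collect_ignore_glob at collection time and skips
# any paths matching the listed globs.
collect_ignore_glob = [
    "vllm-empty/tests/benchmarks/*",
]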
29 changes: 21 additions & 8 deletions vllm_ascend/worker/model_runner_v1.py
@@ -55,6 +55,7 @@
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.utils import vllm_version_is
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -187,14 +188,26 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.model_config.max_model_len,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
-            device=self.device,
-            pin_memory=True,
-            vocab_size=self.model_config.get_vocab_size(),
-        )
+        # Remove this after we drop 0.8.5 support
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
 
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
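The model_runner_v1.py hunk gates InputBatch construction on the installed vllm version: for 0.8.5 and 0.8.5.post1 it keeps the old constructor signature, while newer vllm additionally expects max_num_batched_tokens, which the else branch passes as self.max_num_tokens. For reference, here is a minimal sketch of what a helper like vllm_version_is could look like, assuming it simply compares the installed vllm release string; the actual vllm_ascend.utils implementation may differ:

# Hypothetical sketch of a version-gate helper; the real
# vllm_ascend.utils.vllm_version_is may be implemented differently.
import vllm


def vllm_version_is(target_version: str) -> bool:
    # True only when the installed vllm release string matches exactly,
    # e.g. vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1").
    return vllm.__version__ == target_version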