11 changes: 9 additions & 2 deletions vllm_ascend/attention/attention_v1.py
@@ -30,6 +30,7 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import vllm_version_is
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -141,8 +142,14 @@ def reorder_batch(self, input_batch: "InputBatch",
 
     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
 
         query_lens = self.runner.query_lens
         seq_lens = self.runner.seq_lens_cpu[:num_reqs]
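Note on the version gate used above: `vllm_version_is` is imported from `vllm_ascend.utils`, but its body is not part of this diff. A minimal sketch of such a helper, assuming it simply does an exact match against the installed vllm release string (the real implementation may differ):

```python
# Hypothetical sketch of a version gate like the one imported above; the
# actual vllm_ascend.utils.vllm_version_is may be implemented differently.
import vllm


def vllm_version_is(target: str) -> bool:
    # Exact match against the installed vllm release, e.g. "0.8.5.post1".
    return vllm.__version__ == target
```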
11 changes: 9 additions & 2 deletions vllm_ascend/attention/mla_v1.py
@@ -16,6 +16,7 @@
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import vllm_version_is
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 if TYPE_CHECKING:
@@ -238,8 +239,14 @@ def build(self,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table = (self.runner.input_batch.block_table.
+                           get_device_tensor()[:num_reqs])
+        else:
+            block_table = self.runner.input_batch.block_table[
+                0].get_device_tensor()
+            block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+                block_table[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True)
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
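Both attention metadata builders gate on the same structural change: on vllm 0.8.5 / 0.8.5.post1, `input_batch.block_table` is a single block-table object, while on later versions it is indexable and entry 0 (presumably one table per KV-cache group) provides the device tensor. A condensed sketch of the shared lookup, where `runner` stands in for the builders' `self.runner`:

```python
# Illustrative only; attribute names mirror the diff above, and the
# re-slicing done in the new branch of the diff is omitted here.
import torch

from vllm_ascend.utils import vllm_version_is


def resolve_block_table(runner, num_reqs: int) -> torch.Tensor:
    """Return the device block table covering the first num_reqs requests."""
    if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
        # Old layout: a single BlockTable object on the batch.
        return runner.input_batch.block_table.get_device_tensor()[:num_reqs]
    # New layout: the batch holds a collection of block tables; the diff
    # additionally re-slices entry 0 to runner.max_num_blocks_per_req columns.
    return runner.input_batch.block_table[0].get_device_tensor()
```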
44 changes: 15 additions & 29 deletions vllm_ascend/worker/model_runner_v1.py
@@ -114,6 +114,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
@@ -172,24 +173,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")
 
-        self.attn_backend = get_attn_backend(
-            self.head_size,
-            self.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-            self.model_config.is_attention_free,
-            use_mla=self.model_config.use_mla,
-        )
-        if self.attn_backend is None:
-            error_msg = (
-                f"Error with get_att_backend: {self.head_size=}, "
-                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
-                f"{self.model_config.is_attention_free=}, "
-                f"{self.model_config.use_mla=}")
-            logger.error(error_msg)
-            raise NotImplementedError(
-                "Non-Attention backend is not supported by V1 GPUModelRunner.")
-
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
 
@@ -237,16 +220,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
-        else:
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_blocks_per_req=self.max_num_blocks_per_req,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-            )
 
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
@@ -600,7 +573,10 @@ def _process_reqs(
 
         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        else:
+            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -1206,6 +1182,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
+        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+                block_size=self.cache_config.block_size,
+            )
 
         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
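The model-runner changes follow the same version gate: `cache_config` is kept on the runner so that, on newer vllm, `InputBatch` can be rebuilt inside `initialize_kv_cache` with an explicit `block_size`, while 0.8.5 / 0.8.5.post1 keep the older constructor (taking `max_num_blocks_per_req`) from `__init__`. A sketch of the two constructor shapes, using only the keyword names visible in the diff (the `build_input_batch` helper is hypothetical; `InputBatch` may accept further arguments):

```python
# Sketch only: the keyword sets mirror what the diff shows for each vllm
# version; InputBatch itself comes from vllm.v1.worker.gpu_input_batch.
from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm_ascend.utils import vllm_version_is


def build_input_batch(runner):
    if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
        # Older constructor: capacity expressed as blocks per request.
        return InputBatch(
            max_num_reqs=runner.max_num_reqs,
            max_model_len=runner.model_config.max_model_len,
            max_num_blocks_per_req=runner.max_num_blocks_per_req,
            max_num_batched_tokens=runner.max_num_tokens,
            device=runner.device,
            pin_memory=True,
            vocab_size=runner.model_config.get_vocab_size(),
        )
    # Newer constructor: the cache block size is passed directly, which is
    # why the runner now stores vllm_config.cache_config.
    return InputBatch(
        max_num_reqs=runner.max_num_reqs,
        max_model_len=runner.model_config.max_model_len,
        max_num_batched_tokens=runner.max_num_tokens,
        device=runner.device,
        pin_memory=True,
        vocab_size=runner.model_config.get_vocab_size(),
        block_size=runner.cache_config.block_size,
    )
```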