Commit 4960d65

move out v1 attn
Signed-off-by: jiang.li <jiang1.li@intel.com>
1 parent 3e5b910 commit 4960d65

File tree

3 files changed: 165 additions, 157 deletions

vllm/attention/backends/torch_sdpa.py

Lines changed: 0 additions & 156 deletions
@@ -4,7 +4,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type

-import numpy as np
 import torch
 from torch.nn.functional import scaled_dot_product_attention

@@ -22,9 +21,6 @@
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import make_tensor_with_pad
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.worker.cpu_model_runner import CPUModelRunner
-from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder

 logger = init_logger(__name__)
@@ -78,59 +74,6 @@ def copy_blocks(
         PagedAttention.copy_blocks(kv_caches, src_to_dists)


-class TorchSDPABackendV1:
-    accept_output_buffer: bool = False
-
-    @staticmethod
-    def get_name() -> str:
-        return "TORCH_SDPA"
-
-    @staticmethod
-    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
-        return TorchSDPABackendImpl
-
-    @staticmethod
-    def get_metadata_cls() -> Type["AttentionMetadata"]:
-        return TorchSDPAMetadata
-
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
-    @staticmethod
-    def get_builder_cls() -> Type["TorchSDPAMetadataBuilderV1"]:
-        return TorchSDPAMetadataBuilderV1
-
-    @staticmethod
-    def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,
-        head_size: int,
-    ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
-
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
-
-    @staticmethod
-    def use_cascade_attention(*args, **kwargs) -> bool:
-        return False
-
-
 @dataclass
 class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     """Metadata for TorchSDPABackend.
@@ -343,105 +286,6 @@ def get_seq_len_block_table_args(
         raise AttributeError(f"Invalid attention type {str(attn_type)}")


-class TorchSDPAMetadataBuilderV1:
-
-    def __init__(self, runner: CPUModelRunner) -> None:
-        self.runner = runner
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                      dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                      dtype=np.int64)
-        self.num_prompt_req: int = 0
-
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
-        runner = self.runner
-        seq_lens_np = runner.seq_lens_np[:num_reqs]
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
-        runner.seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens_np, out=runner.seq_start_loc_np[1:num_reqs + 1])
-        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
-        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
-        ) - num_prefill_tokens
-        slot_mapping = runner.slot_mapping_cpu[:num_actual_tokens].long()
-        block_table_tensor = runner.input_batch.block_table.get_device_tensor()
-        attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            slot_mapping=slot_mapping,
-            seq_lens_tensor=runner.
-            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
-            chunked_prefill=True,
-            max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=runner.
-            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
-            kv_start_loc=runner.seq_start_loc_cpu[:num_prompt_req +
-                                                  1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
-            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
-                                                       1],  # for logits index
-            multi_modal_placeholder_index_maps=None,
-            enable_kv_scales_calculation=False,
-        )
-
-        return attn_metadata
-
-
 class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):

     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:

vllm/platforms/cpu.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
             return "vllm.attention.backends.cpu_mla.CPUMLABackend"
         logger.info("Using Torch SDPA backend.")
         if use_v1:
-            return "vllm.attention.backends.torch_sdpa.TorchSDPABackendV1"
+            return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
         else:
             return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
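As a side note, the strings returned above are dotted Python class paths. A minimal sketch of how such a path could be resolved into a backend class, assuming a generic importlib-based helper (load_backend_cls is illustrative only, not vLLM's actual resolver):

# Illustrative only: resolve a dotted class path such as the one this commit adds.
import importlib

def load_backend_cls(qualname: str):
    # Split "package.module.ClassName" into module path and class name,
    # import the module, and fetch the class from it.
    module_path, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)

# Example (requires a vLLM install with the new module layout):
# cls = load_backend_cls("vllm.v1.attention.backends.cpu_attn.TorchSDPABackend")
# cls.get_name()  # -> "TORCH_SDPA_VLLM_V1"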

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+import torch
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
+                                                TorchSDPAMetadata)
+from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.ops.ipex_attn import PagedAttention
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.worker.cpu_model_runner import CPUModelRunner
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+
+class TorchSDPABackend:
+    accept_output_buffer: bool = False
+
+    @staticmethod
+    def get_name() -> str:
+        return "TORCH_SDPA_VLLM_V1"
+
+    @staticmethod
+    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
+        return TorchSDPABackendImpl
+
+    @staticmethod
+    def get_metadata_cls() -> type["AttentionMetadata"]:
+        return TorchSDPAMetadata
+
+    @staticmethod
+    def get_state_cls() -> type["CommonAttentionState"]:
+        return CommonAttentionState
+
+    @staticmethod
+    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
+        return TorchSDPAMetadataBuilderV1
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: list[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+    @staticmethod
+    def use_cascade_attention(*args, **kwargs) -> bool:
+        return False
+
+
+class TorchSDPAMetadataBuilderV1:
+
+    def __init__(self, runner: CPUModelRunner) -> None:
+        self.runner = runner
+
+        # For reorder
+        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
+                                                      dtype=np.int64)
+        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
+                                                      dtype=np.int64)
+        self.num_prompt_req: int = 0
+
+    def reorder_batch(self, input_batch: InputBatch,
+                      scheduler_output: SchedulerOutput) -> bool:
+        prompt_list_idx = 0
+        decode_list_idx = 0
+        for req_index in range(input_batch.num_reqs):
+            if input_batch.num_computed_tokens_cpu[
+                    req_index] < input_batch.num_prompt_tokens[req_index]:
+                # prompt stage
+                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
+                prompt_list_idx += 1
+            else:
+                # decode stage
+                self.reorder_decode_req_index_list[decode_list_idx] = req_index
+                decode_list_idx += 1
+        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
+
+        # Update prompt requests number
+        self.num_prompt_req = prompt_list_idx
+
+        reorder_req_num = 0
+        for req_index in range(decode_list_idx):
+            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
+                reorder_req_num += 1
+            else:
+                break
+
+        if reorder_req_num == 0:
+            return False
+
+        reorder_prompt_list = (
+            self.reorder_prompt_req_index_list[:prompt_list_idx]
+            [-reorder_req_num:])
+        reorder_decode_list = (
+            self.reorder_decode_req_index_list[:decode_list_idx]
+            [:reorder_req_num])
+        assert reorder_decode_list.size == reorder_prompt_list.size
+
+        for idx in range(reorder_req_num):
+            prompt_req_index = reorder_prompt_list[idx].item()
+            decode_req_index = reorder_decode_list[idx].item()
+            input_batch.swap_states(prompt_req_index, decode_req_index)
+
+        return True
+
+    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
+              common_prefix_len: int):
+        runner = self.runner
+        seq_lens_np = runner.seq_lens_np[:num_reqs]
+        num_prompt_req = self.num_prompt_req
+        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
+        ) if num_prompt_req > 0 else 0
+        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
+        ) if num_prompt_req < num_reqs else 0
+        runner.seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens_np, out=runner.seq_start_loc_np[1:num_reqs + 1])
+        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
+        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
+        ) - num_prefill_tokens
+        slot_mapping = runner.slot_mapping_cpu[:num_actual_tokens].long()
+        block_table_tensor = runner.input_batch.block_table.get_device_tensor()
+        attn_metadata = TorchSDPAMetadata(
+            num_prefills=num_prompt_req,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens_tensor=runner.
+            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
+            max_decode_seq_len=max_decode_seq_len,  # decode
+            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            chunked_prefill=True,
+            max_query_len=max_query_len,
+            max_kv_len=max_prefill_seq_len,
+            prefill_query_start_loc=runner.
+            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
+            kv_start_loc=runner.seq_start_loc_cpu[:num_prompt_req +
+                                                  1],  # prefill
+            prefill_block_tables=block_table_tensor[:
+                                                    num_prompt_req],  # prefill
+            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
+                                                       1],  # for logits index
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
+        )
+
+        return attn_metadata
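A note on the builder above: reorder_batch partitions the batch so that all prefill (prompt) requests sit before all decode requests, and build() relies on that invariant when it slices seq_lens, query_start_loc, and the block table by index range. A toy illustration of the partition using plain Python lists (hypothetical stand-ins for the InputBatch fields, not vLLM code):

# Requests still inside their prompt are prefill; the rest are decode.
num_computed_tokens = [4, 0, 7, 0]  # tokens already computed per request
num_prompt_tokens = [4, 6, 7, 5]    # prompt length per request

is_prefill = [c < p for c, p in zip(num_computed_tokens, num_prompt_tokens)]
order = ([i for i, f in enumerate(is_prefill) if f] +
         [i for i, f in enumerate(is_prefill) if not f])
num_prompt_req = sum(is_prefill)

print(order)           # [1, 3, 0, 2]: prefill requests first, then decodes
print(num_prompt_req)  # 2: indices [0, 2) are prefill, [2, 4) are decode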
