@@ -4,7 +4,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type
 
-import numpy as np
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
@@ -22,9 +21,6 @@
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.logger import init_logger
 from vllm.utils import make_tensor_with_pad
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.worker.cpu_model_runner import CPUModelRunner
-from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
 
 logger = init_logger(__name__)
@@ -78,59 +74,6 @@ def copy_blocks( |
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
 
-class TorchSDPABackendV1:
-    accept_output_buffer: bool = False
-
-    @staticmethod
-    def get_name() -> str:
-        return "TORCH_SDPA"
-
-    @staticmethod
-    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
-        return TorchSDPABackendImpl
-
-    @staticmethod
-    def get_metadata_cls() -> Type["AttentionMetadata"]:
-        return TorchSDPAMetadata
-
-    @staticmethod
-    def get_state_cls() -> Type["CommonAttentionState"]:
-        return CommonAttentionState
-
-    @staticmethod
-    def get_builder_cls() -> Type["TorchSDPAMetadataBuilderV1"]:
-        return TorchSDPAMetadataBuilderV1
-
-    @staticmethod
-    def get_kv_cache_shape(
-        num_blocks: int,
-        block_size: int,
-        num_kv_heads: int,
-        head_size: int,
-    ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
-
-    @staticmethod
-    def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: torch.Tensor,
-    ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
-
-    @staticmethod
-    def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: torch.Tensor,
-    ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
-
-    @staticmethod
-    def use_cascade_attention(*args, **kwargs) -> bool:
-        return False
-
-
 @dataclass
 class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
     """Metadata for TorchSDPABackend.
@@ -343,105 +286,6 @@ def get_seq_len_block_table_args( |
         raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
 
-class TorchSDPAMetadataBuilderV1:
-
-    def __init__(self, runner: CPUModelRunner) -> None:
-        self.runner = runner
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                      dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                      dtype=np.int64)
-        self.num_prompt_req: int = 0
-
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
-        runner = self.runner
-        seq_lens_np = runner.seq_lens_np[:num_reqs]
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
-        runner.seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens_np, out=runner.seq_start_loc_np[1:num_reqs + 1])
-        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
-        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
-        ) - num_prefill_tokens
-        slot_mapping = runner.slot_mapping_cpu[:num_actual_tokens].long()
-        block_table_tensor = runner.input_batch.block_table.get_device_tensor()
-        attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            slot_mapping=slot_mapping,
-            seq_lens_tensor=runner.seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
-            chunked_prefill=True,
-            max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=runner.query_start_loc_cpu[:num_prompt_req + 1],  # prefill
-            kv_start_loc=runner.seq_start_loc_cpu[:num_prompt_req + 1],  # prefill
-            prefill_block_tables=block_table_tensor[:num_prompt_req],  # prefill
-            query_start_loc=runner.query_start_loc_cpu[:num_reqs + 1],  # for logits index
-            multi_modal_placeholder_index_maps=None,
-            enable_kv_scales_calculation=False,
-        )
-
-        return attn_metadata
-
-
 class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: