
Commit 4e2774b

refine APIs
Co-authored-by: yan <yan.ma@intel.com>
Co-authored-by: mayuyuace <qiming1.zhang@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent 12a4bd3 commit 4e2774b

File tree: 2 files changed (+37, -7 lines)

vllm/v1/attention/backends/ipex_attn.py

Lines changed: 32 additions & 5 deletions
@@ -132,6 +132,7 @@ def __init__(
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        self.use_irope = use_irope
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -204,19 +205,45 @@ def forward(
             layer._k_scale_float,
             layer._v_scale_float,
         )
+        use_local_attn = \
+            (self.use_irope and attn_metadata.local_attn_metadata is not None)
+
+        if use_local_attn:
+            assert attn_metadata.local_attn_metadata is not None
+            local_metadata = attn_metadata.local_attn_metadata
+            cu_seqlens_q = local_metadata.local_query_start_loc
+            sequesd_k = local_metadata.local_seqused_k
+            max_seqlen_q = local_metadata.local_max_query_len
+            max_seqlen_k = local_metadata.local_max_seq_len
+            block_table = local_metadata.local_block_table
+        else:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            sequesd_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
+
+        if not hasattr(attn_metadata, "seq_start_loc"):
+            cumsum = torch.cumsum(sequesd_k, dim=0)
+            cu_seqlens_k = torch.cat([
+                torch.tensor([0], device=sequesd_k.device, dtype=torch.int32),
+                cumsum
+            ]).to(torch.int32)
+        else:
+            cu_seqlens_k = attn_metadata.seq_start_loc

         ipex_ops.flash_attn_varlen_func(
             output[:num_actual_tokens],
             query[:num_actual_tokens],
             key_cache,
             value_cache,
-            attn_metadata.query_start_loc,
-            attn_metadata.seq_start_loc,
-            attn_metadata.max_query_len,
-            attn_metadata.max_seq_len,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
             self.scale,
             is_casual=True,
-            block_table=attn_metadata.block_table,
+            block_table=block_table,
             alibi_slopes=self.alibi_slopes,
         )
         return output
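As context for the fallback branch above (this note and sketch are not part of the commit): when the attention metadata carries no precomputed seq_start_loc, the new code rebuilds the cumulative key offsets cu_seqlens_k from the per-request key lengths by prepending a zero and taking a running sum. A minimal standalone sketch of that computation, with made-up lengths:

import torch

# Per-request key lengths (what sequesd_k / seq_lens holds in the diff);
# the values here are invented for illustration.
seqused_k = torch.tensor([5, 3, 7], dtype=torch.int32)

# Prepend a zero and take the running sum to get cumulative offsets,
# mirroring the fallback branch added in the hunk above.
cu_seqlens_k = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(seqused_k, dim=0),
]).to(torch.int32)

print(cu_seqlens_k)  # [0, 5, 8, 15]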

vllm/v1/worker/xpu_model_runner.py

Lines changed: 5 additions & 2 deletions
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Optional

 import numpy as np
 import torch

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 if TYPE_CHECKING:
@@ -38,7 +39,9 @@ def _init_device_properties(self) -> None:
     def _sync_device(self) -> None:
         torch.xpu.synchronize()

-    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
+    def _prepare_inputs(
+        self, scheduler_output: "SchedulerOutput"
+    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
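For readers skimming the runner change, a small hedged sketch of what the new _prepare_inputs return annotation promises to callers: a dict of attention metadata keyed by string (assumed here to be layer names), a tensor of logits indices, and optional spec-decode metadata. The stub name, dict key, and tensor values below are illustrative only and do not appear in the commit.

from typing import Any, Optional

import torch

from vllm.v1.spec_decode.metadata import SpecDecodeMetadata


def prepare_inputs_stub() -> tuple[dict[str, Any], torch.Tensor,
                                   Optional[SpecDecodeMetadata]]:
    # Shape the return value the way the annotated method does: attention
    # metadata per key (contents elided here), indices of tokens whose
    # logits are needed, and no spec-decode metadata in this sketch.
    attn_metadata: dict[str, Any] = {"model.layers.0.self_attn": None}
    logits_indices = torch.tensor([4, 9], dtype=torch.int64)
    return attn_metadata, logits_indices, None


attn_metadata, logits_indices, spec_decode_md = prepare_inputs_stub()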
