
Commit 9426843

Enable deepseek-r1 with xpu-kernels. (vllm-project#6)
* [kernel][DS-R1][linear] use default Fp8LinearMethod/Fp8MoEMethod
* [kernel][DS-R1][Attention] enable Triton MLA attention
* enable MHA for deepseek, need padding head_size to make flash attn kernel happy
* not break fp8 path

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent 6ee2247 commit 9426843

File tree

9 files changed: +63 -7 lines

vllm/_ipex_ops.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,7 @@ def flash_attn_varlen_func(
         k_descale=None,
         v_descale=None,
         num_splits=0,
+        return_softmax_lse: bool | None = False,
         s_aux: torch.Tensor | None = None,
     ):
         if out is None:
@@ -97,6 +98,7 @@ def flash_attn_varlen_func(
             window_size=real_window_size,
             # alibi_slopes = alibi_slopes,
             # softcap=softcap,
+            return_softmax_lse=return_softmax_lse,
         )

     @staticmethod
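
The new return_softmax_lse flag exists so callers (the MLA backends in particular) can get back the per-row softmax log-sum-exp and later merge partial attention results, for example attention computed separately over chunked context. The merge itself is a standard log-sum-exp combination; below is a minimal sketch with illustrative tensor shapes, not part of the ipex wrapper's API:

import torch

def merge_attn_partials(o1: torch.Tensor, lse1: torch.Tensor,
                        o2: torch.Tensor, lse2: torch.Tensor) -> torch.Tensor:
    # o1/o2: (tokens, heads, head_size) partial attention outputs
    # lse1/lse2: (tokens, heads) log-sum-exp of the softmax for each partial
    lse = torch.logaddexp(lse1, lse2)          # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of the second partial
    return w1 * o1 + w2 * o2                   # merged attention output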

vllm/attention/ops/triton_flash_attention.py

Lines changed: 12 additions & 0 deletions
@@ -341,7 +341,19 @@ def get_rdna_autotune_configs():
     ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"]


+def get_xpu_autotune_configs():
+    return [
+        triton.Config(
+            {"BLOCK_M": 32, "BLOCK_N": 32, "PRE_LOAD_V": False},
+            num_stages=1,
+            num_warps=2,
+        ),
+    ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"]
+
+
 def get_autotune_configs():
+    if current_platform.is_xpu():
+        return get_xpu_autotune_configs()
     if on_gfx1x():
         return get_rdna_autotune_configs()
     else:
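
For context on how the tuple above is consumed: each get_*_autotune_configs() helper returns a (configs, keys) pair that is presumably fed to Triton's autotuner, so the XPU branch simply restricts the search space to one conservative configuration. A hypothetical sketch of that wiring, using a made-up kernel far smaller than the real one in this file:

import triton
import triton.language as tl

_configs, _keys = get_autotune_configs()  # platform-specific configs + tuning keys

@triton.autotune(configs=_configs, key=_keys)  # re-tune when a key argument changes
@triton.jit
def _toy_kernel(
    out_ptr,
    IS_CAUSAL: tl.constexpr,      # tuning keys ...
    dropout_p,
    BLOCK_DMODEL: tl.constexpr,
    USE_FP8: tl.constexpr,
    BLOCK_M: tl.constexpr,        # ... and the meta-parameters chosen per config
    BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    pid = tl.program_id(0)
    tl.store(out_ptr + pid, pid)  # placeholder body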

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -224,6 +224,7 @@
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
     VLLM_FLATTEN_LOGPROBS: bool = False
     VLLM_XPU_USE_W8A8_GEMM: bool = False
+    VLLM_XPU_ATTN_HEAD_SIZE_PAD: bool = False


 def get_default_cache_root():
@@ -1486,6 +1487,9 @@ def get_vllm_port() -> int | None:
     "VLLM_XPU_USE_W8A8_GEMM": lambda: bool(
         int(os.getenv("VLLM_XPU_USE_W8A8_GEMM", "0"))
     ),
+    "VLLM_XPU_ATTN_HEAD_SIZE_PAD": lambda: bool(
+        int(os.getenv("VLLM_XPU_ATTN_HEAD_SIZE_PAD", "0"))
+    ),
 }

 # --8<-- [end:env-vars-definition]
@@ -1613,6 +1617,7 @@ def compute_hash() -> str:
         "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
         "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
         "VLLM_XPU_USE_W8A8_GEMM",
+        "VLLM_XPU_ATTN_HEAD_SIZE_PAD",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
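
Like the other vLLM flags, the new variable is read lazily via os.getenv, and adding it to environment_variables_to_hash means flipping it also changes the engine's compilation-cache key. A minimal sketch of opting in from Python before engine start (exporting VLLM_XPU_ATTN_HEAD_SIZE_PAD=1 in the shell works just as well); this assumes the usual lazy attribute lookup in vllm.envs:

import os

# Enable head-size padding for XPU flash attention before vLLM reads its envs.
os.environ["VLLM_XPU_ATTN_HEAD_SIZE_PAD"] = "1"

import vllm.envs as envs

assert envs.VLLM_XPU_ATTN_HEAD_SIZE_PAD  # evaluated lazily from os.getenv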

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 2 deletions
@@ -153,7 +153,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         not current_platform.has_device_capability(89)
         or envs.VLLM_TEST_FORCE_FP8_MARLIN
     )
-    if current_platform.is_rocm():
+    if current_platform.is_rocm() or current_platform.is_xpu():
         use_marlin = False
     if use_marlin:
         logger.info_once("Using Marlin backend for FP8 MoE")
@@ -284,7 +284,9 @@ def get_quant_method(
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import

-        if current_platform.is_xpu():
+        # for non-block quant on xpu, we use the xpu fp8 method,
+        # otherwise use triton
+        if current_platform.is_xpu() and self.weight_block_size is None:
             return self.get_xpu_quant_method(layer, prefix)
         if isinstance(layer, LinearBase):
             if is_layer_skipped(
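
The weight_block_size check separates plain (per-tensor or per-channel) FP8 checkpoints, which keep using the XPU-specific quant method, from block-quantized ones such as DeepSeek-R1, which store one scale per weight tile (commonly 128x128) and now fall through to the default Fp8LinearMethod/Fp8MoEMethod. A reference dequantization sketch for that block layout, purely illustrative and not the kernel vLLM actually runs:

import torch

def dequant_block_fp8(weight: torch.Tensor, scale: torch.Tensor,
                      block: tuple[int, int] = (128, 128)) -> torch.Tensor:
    # weight: (N, K) fp8 tensor; scale: (ceil(N/bn), ceil(K/bk)) float32 per-block scales
    bn, bk = block
    n, k = weight.shape
    # Broadcast each per-block scale over its (bn, bk) tile, then trim to (N, K).
    s = scale.repeat_interleave(bn, dim=0)[:n]
    s = s.repeat_interleave(bk, dim=1)[:, :k]
    return weight.to(torch.float32) * s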

vllm/platforms/xpu.py

Lines changed: 4 additions & 0 deletions
@@ -73,6 +73,10 @@ def get_attn_backend_cls(
             raise NotImplementedError("Sparse Attention is not supported on XPU.")
         TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
         FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        TRITON_ATTN_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"  # noqa: E501
+        if use_mla:
+            logger.info_once("Using Triton MLA backend on V1 engine.")
+            return TRITON_ATTN_MLA
         if selected_backend == _Backend.TRITON_ATTN:
             logger.info_once("Using Triton backend.")
             return TRITON_ATTN

vllm/v1/attention/backends/flash_attn.py

Lines changed: 32 additions & 2 deletions
@@ -550,6 +550,31 @@ def forward(
             return output.fill_(0)

         attn_type = self.attn_type
+        output_may_pad = output  # default
+
+        if envs.VLLM_XPU_ATTN_HEAD_SIZE_PAD:
+            logger.warning_once(
+                "VLLM_XPU_ATTN_HEAD_SIZE_PAD is enabled. "
+                "Padding head size to 256 for FlashAttention."
+            )
+            # due to attention head size limitations in current flash attention
+            # kernel(which support 64/128/256 only), we will pad the head size
+            # to 256 for deepseek model.
+            orig_head_size = query.shape[-1]
+            new_shape = query.shape[:-1] + (256,)
+
+            query_pad = query.new_zeros(new_shape)
+            query_pad[..., : query.shape[-1]] = query
+            key_pad = key.new_zeros(new_shape)
+            key_pad[..., : key.shape[-1]] = key
+            value_pad = value.new_zeros(new_shape)
+            value_pad[..., : value.shape[-1]] = value
+            # for output, it's inplace?
+            output_may_pad = output.new_zeros(new_shape)
+
+            query = query_pad
+            key = key_pad
+            value = value_pad

         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
@@ -641,7 +666,7 @@ def forward(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
-                out=output[:num_actual_tokens],
+                out=output_may_pad[:num_actual_tokens],
                 cu_seqlens_q=cu_seqlens_q,
                 max_seqlen_q=max_seqlen_q,
                 seqused_k=seqused_k,
@@ -660,7 +685,12 @@ def forward(
                 num_splits=attn_metadata.max_num_splits,
                 s_aux=self.sinks,
             )
-            return output
+            if envs.VLLM_XPU_ATTN_HEAD_SIZE_PAD:
+                # it's inplace, we should not replace.
+                output[:num_actual_tokens] = output_may_pad[
+                    :num_actual_tokens, :, :orig_head_size
+                ]
+            return output

         # Cascade attention (rare case).
         cascade_attention(
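
The zero padding above is numerically safe: padded query/key lanes contribute nothing to the QK^T dot products, so the softmax weights are unchanged, and the padded value lanes only ever produce zeros, which the final slice back to orig_head_size discards. Stripped of the backend plumbing, the idea is roughly the following sketch (the attention callable and tensor names are placeholders):

import torch

def pad_head_dim(x: torch.Tensor, target: int = 256) -> torch.Tensor:
    # Zero-pad the last (head_size) dimension, e.g. 192 -> 256, so a kernel
    # that only supports head sizes 64/128/256 accepts the tensors.
    padded = x.new_zeros(x.shape[:-1] + (target,))
    padded[..., : x.shape[-1]] = x
    return padded

def attn_with_padded_heads(query, key, value, attn_fn, target: int = 256):
    orig_head_size = query.shape[-1]
    out = attn_fn(pad_head_dim(query, target),
                  pad_head_dim(key, target),
                  pad_head_dim(value, target))
    return out[..., :orig_head_size]  # drop the padded lanes again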

vllm/v1/attention/backends/mla/common.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ class QueryLenSupport(Enum):


 try:
-    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    from vllm.attention.utils.fa_utils import flash_attn_varlen_func

     is_vllm_fa = True
 except ImportError:

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,9 @@
 )
 from vllm.attention.utils.fa_utils import (
     flash_attn_supports_mla,
+    flash_attn_varlen_func,
     get_flash_attn_version,
+    get_scheduler_metadata,
 )
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -31,7 +33,6 @@
 )
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

 logger = init_logger(__name__)

vllm/v1/attention/backends/mla/triton_mla.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
     ):
         if (
-            current_platform.is_rocm()
+            (current_platform.is_rocm() or current_platform.is_xpu())
             and self.use_triton_flash_attn
             and not return_softmax_lse
         ):
