@@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend:
     Construct the backend instance determined by the backend_name string
     argument.
 
-    "XFORMERS" -> construct xformers backend
-
-    TODO: other backends
-
     Note: at time of writing the Attention wrapper automatically selects
     its own backend for Attention.forward(); so the backend instance which
     you generate with this function is not meant to be used for *running*
@@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend:
 
     * Backend instance
     '''
-    if backend_name == STR_XFORMERS_ATTN_VAL:
-        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
-        from vllm.attention.backends.xformers import XFormersBackend
-        return XFormersBackend()
-    elif backend_name == STR_FLASH_ATTN_VAL:
-        from vllm.attention.backends.flash_attn import FlashAttentionBackend
+    if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"):
+        from vllm.v1.attention.backends.xformers import (
+            XFormersAttentionBackend)
+        return XFormersAttentionBackend()
+    if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"):
+        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
         return FlashAttentionBackend()
+    if backend_name == "TRITON_ATTN_VLLM_V1":
+        from vllm.v1.attention.backends.triton_attn import (
+            TritonAttentionBackend)
+        return TritonAttentionBackend()
+    if backend_name == "FLEX_ATTENTION":
+        from vllm.v1.attention.backends.flex_attention import (
+            FlexAttentionBackend)
+        return FlexAttentionBackend()
+    if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"):
+        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
+        return TorchSDPABackend()
+    if backend_name == "FLASHINFER":
+        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
+        return FlashInferBackend()
 
     raise AssertionError(
         f"Unrecognized backend_name {backend_name} for unit test")
 
 
+def make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
+    dtype: torch.dtype,
+    seq_lens: list[int],
+) -> list[Any]:
+    """Create ALiBi biases compatible with xFormers attention tests."""
+    from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
+
+    if alibi_slopes is None:
+        return [None for _ in seq_lens]
+
+    attn_biases: list[Any] = []
+    num_heads = alibi_slopes.shape[0]
+    assert num_heads >= num_kv_heads, (
+        "ALiBi slopes expect at least as many heads as KV heads")
+
+    for seq_len in seq_lens:
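+        # Relative-position matrix: entry (i, j) equals j - i.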
+        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+        bias = bias[None, :] - bias[:, None]
+
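+        # xFormers expects the bias to be sliced from a tensor whose last
+        # dimension is a multiple of 8, so allocate padded storage and slice.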
+        padded_len = (seq_len + 7) // 8 * 8
+        bias_tensor = torch.empty(
+            1,
+            num_heads,
+            seq_len,
+            padded_len,
+            device=alibi_slopes.device,
+            dtype=dtype,
+        )[:, :, :, :seq_len].copy_(bias)
+        bias_tensor.mul_(alibi_slopes[:, None, None])
+        attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
+
+    return attn_biases
+
+
 def _make_metadata_tensors(
     seq_lens: Optional[list[int]],
     context_lens: Optional[list[int]],
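
For context, a minimal sketch of how a kernel unit test might call the two helpers changed above. The import path (`tests.kernels.utils`) and the concrete shapes are assumptions for illustration; only the backend-name strings come from the diff itself.

```python
import torch

# Assumed location of the helpers in this diff (hypothetical import path).
from tests.kernels.utils import make_alibi_bias, make_backend

# Construct a v1 flash-attention backend by name; unrecognized names
# raise AssertionError.
backend = make_backend("FLASH_ATTN_VLLM_V1")

# Build one ALiBi bias per sequence for an xFormers reference path.
num_heads = num_kv_heads = 8
alibi_slopes = torch.rand(num_heads, device="cuda")
attn_biases = make_alibi_bias(
    alibi_slopes,
    num_kv_heads=num_kv_heads,
    dtype=torch.float16,
    seq_lens=[15, 32],
)
assert len(attn_biases) == 2  # one LowerTriangularMaskWithTensorBias each
```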