# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              subclass_attention_backend)


@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    prefix = "EncoderOnlyAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            new_common_attn_metadata = copy(common_attn_metadata)
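            # Encoder-only attention is bidirectional: build the metadata
            # with causal masking disabled so the underlying backend does
            # not apply a causal mask.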
            new_common_attn_metadata.causal = False
            return super().build(common_prefix_len, new_common_attn_metadata,
                                 fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder)

    return attn_backend


class EncoderOnlyAttention(Attention):
    """
    Encoder-only attention is a special case of attention that does not
    need a KV cache.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

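        # The KV cache dtype and block size are only needed to select the
        # underlying attention backend; encoder-only attention itself does
        # not allocate a KV cache.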
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
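            # Pick the platform's attention backend for this head size and
            # dtype, then wrap its metadata builder so attention metadata is
            # always built as non-causal.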
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_encoder_only_attention_backend(
                underlying_attn_backend)
        else:
            # In V0, encoder-only attention is handled inside the backends.
            attn_backend = None

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, \
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_ONLY,
                         **kwargs)
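

# Usage sketch (illustrative only; the local variable names below are
# assumptions, not part of this file). An encoder layer in a model would
# construct this in place of a regular `Attention` layer, e.g.:
#
#     self.attn = EncoderOnlyAttention(num_heads=num_heads,
#                                      head_size=head_dim,
#                                      scale=head_dim**-0.5,
#                                      cache_config=cache_config,
#                                      prefix=f"{prefix}.attn")
#
# Because `create_encoder_only_attention_backend` is wrapped in
# `functools.lru_cache`, all layers that share the same underlying backend
# reuse a single dynamically created "EncoderOnlyAttention_*" backend class.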