From 16b24e7dcd8da5f2ac50f149daa77288fa8c14d7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sun, 13 Oct 2024 19:02:11 -0400 Subject: [PATCH] [Bugfix] Bandaid fix for speculative decoding tests (#9327) --- vllm/worker/model_runner.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9db3261b8ac36..f88b1d84fbcd1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, @@ -1001,6 +1002,17 @@ def __init__( self.graph_block_tables = np.zeros( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) + + # Attention-free but stateful models like Mamba need a placeholder attn + # backend, as the attention metadata is needed to manage internal state. + # However we must bypass attention selection altogether for some models + # used for speculative decoding to avoid a divide-by-zero in + # model_config.get_head_size() + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) + self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.get_sliding_window(), @@ -1008,9 +1020,12 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, - ) - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) + ) if needs_attn_backend else None + if self.attn_backend: + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) + else: + self.attn_state = CommonAttentionState(weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry