@@ -9,7 +9,7 @@
 from transformers import LlamaConfig
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -33,10 +33,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, quant_config=quant_config, prefix=prefix)
+        super().__init__(config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -114,6 +118,8 @@ def __init__(
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
+        current_vllm_config = get_current_vllm_config()
+
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
@@ -123,6 +129,7 @@ def __init__(
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 config=self.config,
+                cache_config=current_vllm_config.cache_config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])
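Below is a minimal sketch (not part of the diff) of the pattern this change introduces: the model constructor reads the active engine configuration via get_current_vllm_config() and forwards its cache_config to each draft LlamaDecoderLayer. The helper build_draft_layer is hypothetical, and running inside vLLM's set_current_vllm_config(...) context is an assumption about how the constructor is invoked.

from vllm.config import get_current_vllm_config

def build_draft_layer(config, prefix: str = ""):
    # Hypothetical helper; LlamaDecoderLayer refers to the draft-layer class
    # defined in this file. Assumes we are inside set_current_vllm_config(...),
    # the context under which vLLM normally constructs models.
    current_vllm_config = get_current_vllm_config()
    return LlamaDecoderLayer(
        config=config,
        # Pass the engine's KV-cache settings down to the layer's attention.
        cache_config=current_vllm_config.cache_config,
        prefix=prefix,
    )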