@@ -106,6 +106,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_type=self.attn_type,
         )

     def _init_qkv(
@@ -134,12 +135,7 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=self.attn_type)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

         output, _ = self.out_proj(attn_output)

@@ -164,6 +160,7 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=prefix,
+            attn_type=AttentionType.ENCODER_DECODER,
         )

     def _init_qkv(
@@ -207,12 +204,13 @@ def forward(
         else:
             k = v = None

-        attn_output = self.attn(q,
-                                k,
-                                v,
-                                kv_cache,
-                                attn_metadata,
-                                attn_type=AttentionType.ENCODER_DECODER)
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            kv_cache,
+            attn_metadata,
+        )

         output, _ = self.out_proj(attn_output)

@@ -734,4 +732,4 @@ def load_weights(self, weights: Iterable[Tuple[str,
         loaded_weights = [(name, loaded_weight)
                           for name, loaded_weight in weights]
         mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
-        return loader.load_weights(loaded_weights, mapper=mapper)
+        return loader.load_weights(loaded_weights, mapper=mapper)