@@ -123,6 +123,69 @@ def maybe_get_vit_flash_attn_backend(
     return attn_backend, flash_attn_varlen_func
 
 
+def _init_kv_cache_quant(
+    layer: nn.Module,
+    quant_config: QuantizationConfig | None,
+    prefix: str,
+    kv_cache_dtype: str,
+    calculate_kv_scales: bool,
+) -> None:
+    """Initializes KV cache scaling factors and quantization method.
+
+    This helper function sets up the KV cache quantization attributes that are
+    shared between Attention and MLAAttention layers. It initializes scale
+    tensors for query, key, value, and probability, and configures the
+    quantization method if applicable.
+
+    Args:
+        layer: The attention layer instance to initialize.
+        quant_config: Optional quantization configuration.
+        prefix: Layer name prefix for quantization method lookup.
+        kv_cache_dtype: The KV cache data type string.
+        calculate_kv_scales: Whether to calculate KV scales dynamically.
+    """
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    layer.kv_cache_dtype = kv_cache_dtype
+    layer.calculate_kv_scales = calculate_kv_scales
+    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
+    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep q/k/v_scale on host (cpu) memory for attention
+    # backends that require the scales to be on host instead of on device.
+    # e.g. Flashinfer
+    layer._q_scale_float = 1.0
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    # The output scale on host memory. This should be the input scale of
+    # the quant op after this attention layer.
+    layer._o_scale_float = None
+
+    quant_method = (
+        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
+    )
+    if quant_method is not None and not isinstance(
+        quant_method, UnquantizedLinearMethod
+    ):
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        layer.quant_method = quant_method
+        layer.quant_method.create_weights(layer)
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -184,57 +247,17 @@ def __init__(
184247 f"num_heads ({ num_heads } ) is not divisible by num_kv_heads ({ num_kv_heads } )"
185248 )
186249
187- # The default k/v_scale is set to 1.0. This is ignored
188- # when kv-cache is not fp8, and should be used with
189- # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
190- # expect the pre-quantized k/v_scale to be loaded along
191- # with the model weights.
192- self .kv_cache_dtype = kv_cache_dtype
193- self .calculate_kv_scales = calculate_kv_scales
194- self ._k_scale = torch .tensor (1.0 , dtype = torch .float32 )
195- self ._v_scale = torch .tensor (1.0 , dtype = torch .float32 )
196- # FlashAttn doesn't support quantizing the kv-cache only
197- # but requires q to be quantized as well.
198- self ._q_scale = torch .tensor (1.0 , dtype = torch .float32 )
199- self ._prob_scale = torch .tensor (1.0 , dtype = torch .float32 )
200-
201- # We also keep q/k/v_scale on host (cpu) memory for attention
202- # backends that require the scales to be on host instead of on device.
203- # e.g. Flashinfer
204- self ._q_scale_float = 1.0
205- self ._k_scale_float = 1.0
206- self ._v_scale_float = 1.0
207-
208- # The output scale on host memory. This should be the input scale of
209- # the quant op after this attention layer.
210- self ._o_scale_float : float | None = None
250+ # Initialize KV cache quantization attributes
251+ _init_kv_cache_quant (
252+ self , quant_config , prefix , kv_cache_dtype , calculate_kv_scales
253+ )
211254
212255 self .num_heads = num_heads
213256 self .head_size = head_size
214257 self .num_kv_heads = num_kv_heads
215258 self .sliding_window = sliding_window
216259 self .has_sink = extra_impl_args .get ("sinks" ) is not None
217260
218- quant_method = (
219- quant_config .get_quant_method (self , prefix = prefix ) if quant_config else None
220- )
221- if quant_method is not None and not isinstance (
222- quant_method , UnquantizedLinearMethod
223- ):
224- assert isinstance (quant_method , BaseKVCacheMethod )
225- # TODO (mgoin): kv cache dtype should be specified in the FP8
226- # checkpoint config and become the "auto" behavior
227- if self .kv_cache_dtype == "fp8_e5m2" :
228- raise ValueError (
229- "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
230- )
231- # If quantization is enabled, we make "k_scale" and "v_scale"
232- # parameters so that it can be loaded from the model checkpoint.
233- # The k/v_scale will then be converted back to native float32
234- # values after weight loading.
235- self .quant_method = quant_method
236- self .quant_method .create_weights (self )
237-
238261 # During model initialization, the default dtype is set as the model
239262 # weight and activation dtype.
240263 dtype = torch .get_default_dtype ()
@@ -636,7 +659,11 @@ def __init__(
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
-        self.kv_cache_dtype = kv_cache_dtype
+
+        # Initialize KV cache quantization attributes
+        _init_kv_cache_quant(
+            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
+        )
 
         dtype = torch.get_default_dtype()
         self.attn_backend = get_attn_backend(
@@ -685,20 +712,6 @@ def __init__(
             )
         ]
 
-        # Align with Attention's scale attributes for MLA backends.
-
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # Host-side mirrors used by some attention backends
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-        self._o_scale_float: float | None = None
-
         self.use_sparse = use_sparse
 
         # Initialize q/k/v range constants.
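
As a quick illustration of the refactor's intent, here is a minimal standalone sketch (the names `init_kv_cache_scales`, `ToyAttention`, and `ToyMLAAttention` are invented for the example and are not part of vLLM): both layer classes delegate the previously duplicated scale setup to one module-level helper, so defaults and host-side mirrors only have to change in one place.

```python
# Standalone sketch of the "shared initializer" pattern, assuming only torch.
import torch
from torch import nn


def init_kv_cache_scales(
    layer: nn.Module, kv_cache_dtype: str, calculate_kv_scales: bool
) -> None:
    """Attach default (1.0) scale tensors plus host-side float mirrors."""
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales
    # Device-side scales; a checkpoint loader may overwrite them later.
    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)
    # Host-side copies for backends that want plain Python floats.
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
    layer._o_scale_float = None


class ToyAttention(nn.Module):
    def __init__(self, kv_cache_dtype: str = "auto") -> None:
        super().__init__()
        init_kv_cache_scales(self, kv_cache_dtype, calculate_kv_scales=False)


class ToyMLAAttention(nn.Module):
    def __init__(self, kv_cache_dtype: str = "auto") -> None:
        super().__init__()
        init_kv_cache_scales(self, kv_cache_dtype, calculate_kv_scales=False)


if __name__ == "__main__":
    attn, mla = ToyAttention(), ToyMLAAttention()
    # Both layers end up with identical default scale attributes.
    assert attn._k_scale.item() == mla._k_scale.item() == 1.0
    print(attn.kv_cache_dtype, mla._v_scale_float)
```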