 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
@@ -655,98 +650,4 @@ def cast_overflow_tensors(
     if tensors.isinf().any() or tensors.isnan().any():
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
-    return tensors
-
-class QKVCrossParallelLinear(torch.nn.Module):
-
-    def __init__(self,
-                 hidden_size: int,
-                 head_size: int,
-                 total_num_heads: int,
-                 total_num_kv_heads: Optional[int] = None,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__()
-        # Empty placeholders for loading as a single module.
-        self.weight = torch.nn.Parameter()
-        set_weight_attrs(self.weight, {
-            "weight_loader": self.weight_loader_weight,
-        })
-        # Use a dictionary to avoid submodules parameters auto-registration:
-        # drop-in replacement for a `QKVParallelLinear` module.
-        self.proj = dict()
-        self.proj["q_proj_decoder"] = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=total_num_heads * head_size,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.q_proj_decoder")
-
-        self.proj["kv_proj_encoder"] = QKVParallelLinear(
-            hidden_size=hidden_size,
-            head_size=head_size,
-            total_num_heads=0,
-            total_num_kv_heads=total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-            skip_bias_add=skip_bias_add,
-            params_dtype=params_dtype,
-            prefix=f"{prefix}.kv_proj_encoder")
-
-        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
-        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
-
-        if bias:
-            self.bias = torch.nn.Parameter()
-            set_weight_attrs(self.bias, {
-                "weight_loader": self.weight_loader_bias,
-            })
-
-    @property
-    def q_proj_decoder(self):
-        return self.proj["q_proj_decoder"]
-
-    @property
-    def kv_proj_encoder(self):
-        return self.proj["kv_proj_encoder"]
-
-    def forward(self, decoder_hidden_states, encoder_hidden_states):
-        q, _ = self.q_proj_decoder(decoder_hidden_states)
-        if encoder_hidden_states is None:
-            # Encoder KV already cached.
-            k = None
-            v = None
-        else:
-            # Prefill phase, encoder KV cached here.
-            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
-            # Split kv in half
-            k, v = kv_enc.split(self.kv_size, dim=-1)
-        return q, k, v
-
-    def weight_loader_weight(self,
-                             param: torch.nn.Parameter,
-                             loaded_weight: torch.Tensor,
-                             loaded_shard_id: Optional[str] = None):
-        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
-        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.weight
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
-
-    def weight_loader_bias(self,
-                           param: torch.nn.Parameter,
-                           loaded_weight: torch.Tensor,
-                           loaded_shard_id: Optional[str] = None):
-        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
-            else self.kv_proj_encoder.bias
-        param.weight_loader(
-            param,
-            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
-                param, loaded_weight, loaded_shard_id)
+    return tensors
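
For reference, a minimal usage sketch of the QKVCrossParallelLinear class removed above, acting as a drop-in replacement for QKVParallelLinear in a cross-attention layer. The surrounding layer attributes (self.head_dim, self.num_heads, self.num_kv_heads, quant_config, prefix) are illustrative assumptions, not code from this commit:

    # Hypothetical cross-attention layer (illustrative only).
    self.qkv_proj = QKVCrossParallelLinear(hidden_size=hidden_size,
                                           head_size=self.head_dim,
                                           total_num_heads=self.num_heads,
                                           total_num_kv_heads=self.num_kv_heads,
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.qkv_proj")

    # Decoder queries are always projected; encoder K/V are only produced
    # during prefill, when encoder_hidden_states is not None. On decode
    # steps the caller passes None and gets k = v = None back, so the
    # previously cached encoder KV is reused.
    q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)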