@@ -667,16 +667,24 @@ def __init__(
             eps=config.rms_norm_eps)
         if config.attention_type == 0:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_linear_attention_alpha', 1)
+                config, 'layernorm_linear_attention_alpha',
+                getattr(config, 'linear_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_linear_attention_beta', 1)
+                config, 'layernorm_linear_attention_beta',
+                getattr(config, 'linear_attn_beta_factor', 1))
         else:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_full_attention_alpha', 1)
+                config, 'layernorm_full_attention_alpha',
+                getattr(config, 'full_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_full_attention_beta', 1)
-        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
-        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
+                config, 'layernorm_full_attention_beta',
+                getattr(config, 'full_attn_beta_factor', 1))
+        self.layernorm_mlp_alpha = getattr(
+            config, 'layernorm_mlp_alpha',
+            getattr(config, 'mlp_alpha_factor', 1))
+        self.layernorm_mlp_beta = getattr(
+            config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
+                                                  1))
         self.postnorm = getattr(config, 'postnorm', False)
         self.shared_moe = False
 
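The nested `getattr` calls above give a three-step fallback: the legacy MiniMax config name is preferred, then the newer HF-style `*_factor` name, then a default of 1. A minimal standalone sketch (the `SimpleNamespace` configs are hypothetical, not part of this change) showing how the chain resolves:

```python
from types import SimpleNamespace

def resolve_alpha(config):
    # Same fallback chain as the patch: legacy name, then HF factor name, then 1.
    return getattr(config, 'layernorm_full_attention_alpha',
                   getattr(config, 'full_attn_alpha_factor', 1))

legacy_cfg = SimpleNamespace(layernorm_full_attention_alpha=3.5)
hf_cfg = SimpleNamespace(full_attn_alpha_factor=3.5)
bare_cfg = SimpleNamespace()

print(resolve_alpha(legacy_cfg))  # 3.5 -> legacy attribute wins when present
print(resolve_alpha(hf_cfg))      # 3.5 -> falls back to the HF factor name
print(resolve_alpha(bare_cfg))    # 1   -> neither attribute set, use the default
```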
@@ -794,6 +802,18 @@ def __init__(
         self.decoder_attention_types = getattr(
             config, "attn_type_list", False) or getattr(
                 config, "decoder_attention_types", False)
+        # The HF format uses "layer_types" instead of "attn_type_list"
+        # where "linear_attention" is 0 and "full_attention" is 1
+        if not self.decoder_attention_types and hasattr(config, "layer_types"):
+            self.decoder_attention_types = []
+            for layer_type in config.layer_types:
+                if layer_type == "linear_attention":
+                    self.decoder_attention_types.append(0)
+                elif layer_type == "full_attention":
+                    self.decoder_attention_types.append(1)
+                else:
+                    raise ValueError(f"Unsupported layer type: {layer_type}")
+        # Default to full attention
         if not self.decoder_attention_types:
             self.decoder_attention_types = [1] * config.num_hidden_layers
         self.num_layers = config.num_hidden_layers
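The new branch normalizes HF-style `layer_types` strings into the 0/1 codes the rest of the file expects. A small sketch of the same mapping (the `layer_types` list here is hypothetical), using a lookup table instead of the if/elif chain; an unknown type would surface as a `KeyError` rather than the patch's `ValueError`:

```python
# Hypothetical HF-style config value; 0 = linear attention, 1 = full attention.
layer_types = ["linear_attention", "full_attention", "linear_attention"]

_LAYER_TYPE_TO_CODE = {"linear_attention": 0, "full_attention": 1}
decoder_attention_types = [_LAYER_TYPE_TO_CODE[t] for t in layer_types]

assert decoder_attention_types == [0, 1, 0]
```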
@@ -1022,8 +1042,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         else:
             self.lm_head = PPMissingLayer()
         self.lm_head.float()
-        flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
-                                if attn_type == 1)
+        flash_layer_count = sum(
+            1 for attn_type in self.model.decoder_attention_types
+            if attn_type == 1)
         self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
         return
 
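`flash_layer_count` now comes from the normalized `decoder_attention_types` on the model rather than `config.attn_type_list`, so it also works for checkpoints that only provide `layer_types`. A sketch with a hypothetical layer layout, showing that the KV-cache list ends up with one slot per full-attention layer:

```python
import torch

# Hypothetical normalized layout: 0 = linear attention, 1 = full attention.
decoder_attention_types = [0, 0, 1, 0, 1]

flash_layer_count = sum(1 for attn_type in decoder_attention_types
                        if attn_type == 1)
kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]

assert flash_layer_count == 2 and len(kv_cache) == 2
```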
@@ -1085,9 +1106,10 @@ def which_layer(name: str) -> int:
             return None
 
         def is_linear_attn_layer(layer_idx: int) -> bool:
-            if layer_idx is None or not hasattr(self.config, "attn_type_list"):
+            if layer_idx is None or layer_idx >= len(
+                    self.model.decoder_attention_types):
                 return False
-            return self.config.attn_type_list[layer_idx] == 0
+            return self.model.decoder_attention_types[layer_idx] == 0
 
         def is_moe_weight(name: str) -> bool:
             return "block_sparse_moe" in name and not name.endswith(".bias")
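The length guard replaces the old `hasattr(self.config, "attn_type_list")` check, so an out-of-range index is treated as "not a linear-attention layer" instead of raising. A standalone sketch with hypothetical inputs:

```python
decoder_attention_types = [0, 1, 0]  # hypothetical: 0 = linear, 1 = full

def is_linear_attn_layer(layer_idx):
    if layer_idx is None or layer_idx >= len(decoder_attention_types):
        return False
    return decoder_attention_types[layer_idx] == 0

assert is_linear_attn_layer(0) is True
assert is_linear_attn_layer(1) is False
assert is_linear_attn_layer(7) is False    # out of range -> safely False
assert is_linear_attn_layer(None) is False
```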
@@ -1275,7 +1297,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
         for name, loaded_weight in weights:
             weight_at_layer = which_layer(name)
             if weight_at_layer and weight_at_layer >= len(
-                    self.config.attn_type_list):
+                    self.model.decoder_attention_types):
                 continue
 
             if is_layer_norm_weight(name):
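During weight loading, any tensor whose layer index falls beyond the configured layer count is skipped, and the bound now comes from `decoder_attention_types` as well. A simplified sketch (the `which_layer` parser and weight names below are hypothetical stand-ins for the real helpers):

```python
import re

decoder_attention_types = [0, 1]  # hypothetical: only two decoder layers configured

def which_layer(name):
    # Hypothetical stand-in: pull the index out of "model.layers.<idx>." names.
    match = re.search(r"layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None

weights = ["model.layers.0.attn.weight", "model.layers.3.attn.weight",
           "lm_head.weight"]
kept = []
for name in weights:
    weight_at_layer = which_layer(name)
    if weight_at_layer and weight_at_layer >= len(decoder_attention_types):
        continue  # layer index beyond the configured stack -> skip this weight
    kept.append(name)

assert kept == ["model.layers.0.attn.weight", "lm_head.weight"]
```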