@@ -118,76 +118,97 @@ class MODEL_TENSOR(IntEnum):
118
118
MODEL_ARCH .STARCODER : "starcoder" ,
119
119
}
120
120
121
# Canonical GGUF name for each tensor kind. Per-block tensors embed a
# "{bid}" placeholder that is filled in with the block index at lookup time.
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
    MODEL_TENSOR.POS_EMBD:      "position_embd",
    MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
    MODEL_TENSOR.OUTPUT:        "output",
    MODEL_TENSOR.ROPE_FREQS:    "rope_freqs",

    MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.attn_norm",
    MODEL_TENSOR.ATTN_NORM_2:   "blk.{bid}.attn_norm_2",
    MODEL_TENSOR.ATTN_QKV:      "blk.{bid}.attn_qkv",
    MODEL_TENSOR.ATTN_Q:        "blk.{bid}.attn_q",
    MODEL_TENSOR.ATTN_K:        "blk.{bid}.attn_k",
    MODEL_TENSOR.ATTN_V:        "blk.{bid}.attn_v",
    MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn_output",
    MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
    MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
    MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",
    MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
    MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
}

# Which tensor kinds each architecture actually uses; the name strings
# themselves live only in TENSOR_NAMES so they are not duplicated per arch.
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
    MODEL_ARCH.LLAMA: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.GPTNEOX: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.FALCON: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_NORM_2,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.BAICHUAN: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.STARCODER: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.GPT2: [
        # TODO
    ],
    # TODO
}
193
214
@@ -338,28 +359,24 @@ class TensorNameMap:
338
359
339
360
mapping : dict [str , tuple [MODEL_TENSOR , str ]]
340
361
341
- tensor_names : dict [MODEL_TENSOR , str ]
342
-
343
362
def __init__ (self , arch : MODEL_ARCH , n_blocks : int ):
344
- mapping = self .mapping = {}
345
- tensor_names = self .tensor_names = MODEL_TENSOR_NAMES [arch ]
363
+ self .mapping = {}
346
364
for tensor , keys in self .mappings_cfg .items ():
347
- tensor_name = tensor_names .get (tensor )
348
- if tensor_name is None :
365
+ if tensor not in MODEL_TENSORS [arch ]:
349
366
continue
350
- mapping [tensor_name ] = (tensor , tensor_name )
367
+ tensor_name = TENSOR_NAMES [tensor ]
368
+ self .mapping [tensor_name ] = (tensor , tensor_name )
351
369
for key in keys :
352
- mapping [key ] = (tensor , tensor_name )
370
+ self . mapping [key ] = (tensor , tensor_name )
353
371
for bid in range (n_blocks ):
354
372
for tensor , keys in self .block_mappings_cfg .items ():
355
- tensor_name = tensor_names .get (tensor )
356
- if tensor_name is None :
373
+ if tensor not in MODEL_TENSORS [arch ]:
357
374
continue
358
- tensor_name = tensor_name .format (bid = bid )
359
- mapping [tensor_name ] = (tensor , tensor_name )
375
+ tensor_name = TENSOR_NAMES [ tensor ] .format (bid = bid )
376
+ self . mapping [tensor_name ] = (tensor , tensor_name )
360
377
for key in keys :
361
378
key = key .format (bid = bid )
362
- mapping [key ] = (tensor , tensor_name )
379
+ self . mapping [key ] = (tensor , tensor_name )
363
380
364
381
def get_type_and_name (self , key : str , try_suffixes : Sequence [str ] = ()) -> tuple [MODEL_TENSOR , str ] | None :
365
382
result = self .mapping .get (key )
@@ -800,22 +817,25 @@ class SpecialVocab:
800
817
special_token_types : tuple [str , ...] = ('bos' , 'eos' , 'unk' , 'sep' , 'pad' )
801
818
special_token_ids : dict [str , int ] = {}
802
819
803
- def __init__ (self , path : Path , load_merges : bool = False , special_token_types : tuple [str , ...] | None = None ):
820
+ def __init__ (
821
+ self , path : str | os .PathLike [str ], load_merges : bool = False ,
822
+ special_token_types : tuple [str , ...] | None = None ,
823
+ ):
804
824
self .special_token_ids = {}
805
825
self .load_merges = load_merges
806
826
if special_token_types is not None :
807
827
self .special_token_types = special_token_types
808
- self .load ( path )
828
+ self ._load ( Path ( path ) )
809
829
810
- def load (self , path : Path ):
811
- if not self .try_load_from_tokenizer_json (path ):
812
- self .try_load_from_config_json (path )
830
+ def _load (self , path : Path ) -> None :
831
+ if not self ._try_load_from_tokenizer_json (path ):
832
+ self ._try_load_from_config_json (path )
813
833
814
- def try_load_from_tokenizer_json (self , path : Path ) -> bool :
834
+ def _try_load_from_tokenizer_json (self , path : Path ) -> bool :
815
835
tokenizer_file = path / 'tokenizer.json'
816
836
if not tokenizer_file .is_file ():
817
837
return False
818
- with open (tokenizer_file , 'r' , encoding = 'utf-8' ) as f :
838
+ with open (tokenizer_file , encoding = 'utf-8' ) as f :
819
839
tokenizer = json .load (f )
820
840
if self .load_merges :
821
841
merges = tokenizer .get ('model' , {}).get ('merges' )
@@ -825,7 +845,7 @@ def try_load_from_tokenizer_json(self, path: Path) -> bool:
825
845
added_tokens = tokenizer .get ('added_tokens' )
826
846
if added_tokens is None or not tokenizer_config_file .is_file ():
827
847
return True
828
- with open (tokenizer_config_file , 'r' , encoding = 'utf-8' ) as f :
848
+ with open (tokenizer_config_file , encoding = 'utf-8' ) as f :
829
849
tokenizer_config = json .load (f )
830
850
for typ in self .special_token_types :
831
851
entry = tokenizer_config .get (f'{ typ } _token' )
@@ -844,19 +864,19 @@ def try_load_from_tokenizer_json(self, path: Path) -> bool:
844
864
break
845
865
return True
846
866
847
- def try_load_from_config_json (self , path : Path ) -> bool :
867
+ def _try_load_from_config_json (self , path : Path ) -> bool :
848
868
config_file = path / 'config.json'
849
869
if not config_file .is_file ():
850
870
return False
851
- with open (config_file , 'r' , encoding = 'utf-8' ) as f :
871
+ with open (config_file , encoding = 'utf-8' ) as f :
852
872
config = json .load (f )
853
873
for typ in self .special_token_types :
854
874
maybe_token_id = config .get (f'{ typ } _token_id' )
855
875
if isinstance (maybe_token_id , int ) and maybe_token_id >= 0 :
856
876
self .special_token_ids [typ ] = maybe_token_id
857
877
return True
858
878
859
- def add_to_gguf (self , gw : GGUFWriter ):
879
+ def add_to_gguf (self , gw : GGUFWriter ) -> None :
860
880
if len (self .merges ) > 0 :
861
881
print (f'gguf: Adding { len (self .merges )} merge(s).' )
862
882
gw .add_token_merges (self .merges )
@@ -868,8 +888,8 @@ def add_to_gguf(self, gw: GGUFWriter):
868
888
print (f'gguf: Setting special token type { typ } to { tokid } ' )
869
889
handler (tokid )
870
890
871
- def __repr__ (self ):
872
- return f'<SpecialVocab with { len (self .merges )} merges and special tokens { self .special_token_ids if self . special_token_ids else "unset" } >'
891
+ def __repr__ (self ) -> str :
892
+ return f'<SpecialVocab with { len (self .merges )} merges and special tokens { self .special_token_ids or "unset" } >'
873
893
874
894
875
895
# Example usage:
0 commit comments