@@ -203,6 +203,8 @@ def from_model_architecture(model_architecture):
             return CodeShellModel
         if model_architecture == "OrionForCausalLM":
            return OrionModel
+        if model_architecture == "InternLM2ForCausalLM":
+            return InternLM2Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -254,6 +256,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.CODESHELL
         if arch == "OrionForCausalLM":
             return gguf.MODEL_ARCH.ORION
+        if arch == "InternLM2ForCausalLM":
+            return gguf.MODEL_ARCH.INTERNLM2
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1344,6 +1348,153 @@ def write_tensors(self):
                 self.gguf_writer.add_tensor("output.weight", data)
                 print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
+class InternLM2Model(Model):
+    def set_vocab(self):
+        # (TODO): Is there a better way?
+        # Copied from _set_vocab_sentencepiece; the only difference is that we treat the character
+        # \x00 specially and convert it into an emoji character, to prevent it from being mistakenly
+        # recognized as an empty string in C++.
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        tokens: list[bytes] = []
+        scores: list[float] = []
+        toktypes: list[int] = []
+
+        if not tokenizer_path.is_file():
+            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
+            sys.exit(1)
+
+        sentencepiece_model = model.ModelProto()
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+
+        tokenizer = SentencePieceProcessor(str(tokenizer_path))
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        for token_id in range(vocab_size):
+            piece = tokenizer.id_to_piece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.get_score(token_id)
+            if text == b"\x00":
+                # (TODO): fixme
+                # Hack: replace the \x00 characters.
+                print(f"InternLM2 convert token '{text}' to '🐉'!")
+                text = "🐉".encode("utf-8")
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                added_tokens_json = json.load(f)
+
+                for key in added_tokens_json:
+                    tokens.append(key.encode("utf-8"))
+                    scores.append(-1000.0)
+                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name("InternLM2")
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+
+    def post_write_tensors(self, tensor_map, name, data_torch):
+        old_dtype = data_torch.dtype
+
+        # convert any unsupported data types to float32
+        if data_torch.dtype not in (torch.float16, torch.float32):
+            data_torch = data_torch.to(torch.float32)
+
+        data = data_torch.squeeze().numpy()
+
+        # map tensor names
+        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+        if new_name is None:
+            print(f"Can not map tensor {name!r}")
+            sys.exit()
+
+        n_dims = len(data.shape)
+        data_dtype = data.dtype
+
+        # if f32 desired, convert any float16 to float32
+        if self.ftype == 0 and data_dtype == np.float16:
+            data = data.astype(np.float32)
+
+        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+        if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            data = data.astype(np.float32)
+
+        # if f16 desired, convert any float32 2-dim weight tensors to float16
+        if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+            data = data.astype(np.float16)
+
+        print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+        self.gguf_writer.add_tensor(new_name, data)
+
+    def write_tensors(self):
+        from einops import rearrange
+
+        num_heads = self.hparams.get("num_attention_heads")
+        num_kv_heads = self.hparams.get("num_key_value_heads")
+        hidden_size = self.hparams.get("hidden_size")
+        q_per_kv = num_heads // num_kv_heads
+        head_dim = hidden_size // num_heads
+        num_groups = num_heads // q_per_kv
+
+        block_count = self.hparams["num_hidden_layers"]
+        model_kv = dict(self.get_tensors())
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            if re.match(qkv_pattern, name):
+                bid = re.findall(qkv_pattern, name)[0]
+                qkv = data_torch
+                qkv = rearrange(qkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+                q, k, v = qkv[..., :q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+                q = rearrange(q, "o g n i -> o (g n i)").T
+                k = rearrange(k, "o g n i -> o (g n i)").T
+                v = rearrange(v, "o g n i -> o (g n i)").T
+                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
+                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
+                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
+            else:
+                self.post_write_tensors(tensor_map, name, data_torch)
+
+
 ###### CONVERSION LOGIC ######
 
 
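As a reference, here is a minimal standalone sketch (not part of the commit) of the wqkv split performed in InternLM2Model.write_tensors above: the fused weight packs, per KV group, q_per_kv query heads followed by one key head and one value head along the output dimension, so a single einops rearrange plus slicing recovers separate wq/wk/wv tensors with the expected shapes. The toy hyperparameter values below are assumptions chosen for illustration, not values read from a real InternLM2 checkpoint.

# Standalone sketch: split a fused wqkv weight into wq / wk / wv (toy sizes, assumed for illustration).
import torch
from einops import rearrange

num_heads = 8                          # assumed toy value
num_kv_heads = 2                       # assumed toy value
hidden_size = 32                       # assumed toy value
q_per_kv = num_heads // num_kv_heads   # query heads per KV group -> 4
head_dim = hidden_size // num_heads    # -> 4
num_groups = num_heads // q_per_kv     # equals num_kv_heads -> 2

# Fused layout: each KV group stores q_per_kv query heads, then one key head, then one value head.
wqkv = torch.randn(num_groups * (q_per_kv + 2) * head_dim, hidden_size)

# Same transformation as in write_tensors above.
qkv = rearrange(wqkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
q = rearrange(qkv[..., :q_per_kv, :], "o g n i -> o (g n i)").T
k = rearrange(qkv[..., q_per_kv: q_per_kv + 1, :], "o g n i -> o (g n i)").T
v = rearrange(qkv[..., q_per_kv + 1: q_per_kv + 2, :], "o g n i -> o (g n i)").T

print(q.shape)  # (num_heads * head_dim, hidden_size)     -> torch.Size([32, 32])
print(k.shape)  # (num_kv_heads * head_dim, hidden_size)  -> torch.Size([8, 32])
print(v.shape)  # (num_kv_heads * head_dim, hidden_size)  -> torch.Size([8, 32])

# The first q_per_kv * head_dim rows of the fused weight are exactly group 0's query heads.
assert torch.equal(q[:q_per_kv * head_dim], wqkv[:q_per_kv * head_dim])

The same indexing is what lets the slices [..., :q_per_kv, :], [..., q_per_kv: q_per_kv + 1, :], and [..., q_per_kv + 1: q_per_kv + 2, :] in the diff pick out the query, key, and value heads respectively.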