@@ -13,6 +13,7 @@ class SpecialVocab:
13
13
merges : list [str ]
14
14
add_special_token : dict [str , bool ]
15
15
special_token_ids : dict [str , int ]
16
+ chat_template : str | None
16
17
17
18
def __init__ (
18
19
self , path : str | os .PathLike [str ], load_merges : bool = False ,
@@ -24,6 +25,7 @@ def __init__(
24
25
self .n_vocab = n_vocab
25
26
self .load_merges = load_merges
26
27
self .merges = []
28
+ self .chat_template = None
27
29
if special_token_types is not None :
28
30
self .special_token_types = special_token_types
29
31
else :
@@ -67,6 +69,10 @@ def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
67
69
if not quiet :
68
70
print (f'gguf: Setting add_{ typ } _token to { value } ' )
69
71
add_handler (value )
72
+ if self .chat_template is not None :
73
+ if not quiet :
74
+ print (f'gguf: Setting chat_template to { self .chat_template } ' )
75
+ gw .add_chat_template (self .chat_template )
70
76
71
77
def _load (self , path : Path ) -> None :
72
78
self ._try_load_from_tokenizer_json (path )
@@ -132,6 +138,14 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
132
138
return True
133
139
with open (tokenizer_config_file , encoding = 'utf-8' ) as f :
134
140
tokenizer_config = json .load (f )
141
+ chat_template = tokenizer_config .get ('chat_template' )
142
+ if chat_template is None or isinstance (chat_template , str ):
143
+ self .chat_template = chat_template
144
+ else :
145
+ print (
146
+ f'gguf: WARNING: Bad type for chat_template field in { tokenizer_config_file !r} - ignoring' ,
147
+ file = sys .stderr
148
+ )
135
149
for typ in self .special_token_types :
136
150
add_entry = tokenizer_config .get (f'add_{ typ } _token' )
137
151
if isinstance (add_entry , bool ):
0 commit comments