@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
104104 return matched_files [0 ]
105105
106106
107+ def make_mistral_chat_completion_request (
108+ messages : List ["ChatCompletionMessageParam" ],
109+ tools : Optional [List [Dict [str ,
110+ Any ]]] = None ) -> "ChatCompletionRequest" :
111+ last_message = cast (Dict [str , Any ], messages [- 1 ])
112+ if last_message ["role" ] == "assistant" :
113+ last_message ["prefix" ] = True
114+
115+ last_message = cast (Dict [str , Any ], messages [- 1 ])
116+ if last_message ["role" ] == "assistant" :
117+ last_message ["prefix" ] = True
118+
119+ # mistral-common requires AssistantMessage content to be string [1].
120+ #
121+ # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
122+ for message in messages :
123+ if message .get ("role" ) == "assistant" :
124+ content = message .get ("content" )
125+ if isinstance (content , list ):
126+ content = "\n " .join (chunk .get ("text" ) for chunk in content )
127+ message ["content" ] = content
128+
129+ # The Mistral client, in comparison to the OpenAI client, requires the
130+ # "parameters" dict to be present, even if it's empty.
131+ if tools :
132+ for function in [
133+ tool ["function" ] for tool in tools
134+ if tool ["type" ] == "function"
135+ ]:
136+ function .setdefault ("parameters" , {})
137+
138+ from mistral_common .protocol .instruct .request import ChatCompletionRequest
139+ return ChatCompletionRequest (messages = messages ,
140+ tools = tools ) # type: ignore[type-var]
141+
142+
107143class MistralTokenizer :
108144
109145 def __init__ (self , tokenizer : "PublicMistralTokenizer" ) -> None :
@@ -283,27 +319,10 @@ def encode(self, prompt: str) -> List[int]:
283319
284320 def apply_chat_template (self ,
285321 messages : List ["ChatCompletionMessageParam" ],
286- tools : Optional [Dict [str , Any ]] = None ,
322+ tools : Optional [List [ Dict [str , Any ] ]] = None ,
287323 ** kwargs ) -> List [int ]:
288324
289- last_message = cast (Dict [str , Any ], messages [- 1 ])
290- if last_message ["role" ] == "assistant" :
291- last_message ["prefix" ] = True
292-
293- from mistral_common .protocol .instruct .request import (
294- ChatCompletionRequest )
295-
296- # mistral-common requires AssistantMessage content to be string [1].
297- #
298- # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
299- for message in messages :
300- if message .get ("role" ) == "assistant" :
301- content = message .get ("content" )
302- if isinstance (content , list ):
303- content = "\n " .join (chunk .get ("text" ) for chunk in content )
304- message ["content" ] = content
305- request = ChatCompletionRequest (messages = messages ,
306- tools = tools ) # type: ignore[type-var]
325+ request = make_mistral_chat_completion_request (messages , tools )
307326 encoded = self .mistral .encode_chat_completion (request )
308327
309328 # encode-decode to get clean prompt
0 commit comments