@@ -10,7 +10,7 @@

 import os
 from dataclasses import asdict
-from typing import NamedTuple, Optional
+from typing import Any, NamedTuple, Optional

 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
@@ -30,7 +30,9 @@

 class ModelRequestData(NamedTuple):
     engine_args: EngineArgs
-    prompt: str
+    prompt: Optional[str] = None
+    prompt_token_ids: Optional[dict[str, list[int]]] = None
+    multi_modal_data: Optional[dict[str, Any]] = None
     stop_token_ids: Optional[list[int]] = None
     lora_requests: Optional[list[LoRARequest]] = None

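The widened ModelRequestData lets a model runner return either a chat-templated prompt string (the existing path, where main() attaches the audio assets itself) or pre-tokenized prompt ids together with its own multi_modal_data. A minimal sketch of the two shapes, using placeholder values that are not taken from this diff:

# Illustrative only: placeholder model id, prompt string, and token ids.
from vllm import EngineArgs

engine_args = EngineArgs(model="placeholder/model")

# 1) String prompt: main() looks up the audio assets and builds mm_data itself.
text_style = ModelRequestData(
    engine_args=engine_args,
    prompt="<placeholder chat-templated prompt containing audio tokens>",
)

# 2) Pre-tokenized prompt: the runner's own tokenizer produced the ids, and the
#    decoded (audio_array, sampling_rate) tuples travel in multi_modal_data.
token_style = ModelRequestData(
    engine_args=engine_args,
    prompt_token_ids=[1, 2, 3],  # placeholder ids
    multi_modal_data={"audio": []},  # list of (audio_array, sampling_rate)
)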
@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.


+# Voxtral
+def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
+    from mistral_common.audio import Audio
+    from mistral_common.protocol.instruct.messages import (
+        AudioChunk,
+        RawAudio,
+        TextChunk,
+        UserMessage,
+    )
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    model_name = "mistralai/Voxtral-Mini-3B-2507"
+    tokenizer = MistralTokenizer.from_hf_hub(model_name)
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"audio": audio_count},
+        config_format="mistral",
+        load_format="mistral",
+        tokenizer_mode="mistral",
+        enforce_eager=True,
+        enable_chunked_prefill=False,
+    )
+
+    text_chunk = TextChunk(text=question)
+    audios = [
+        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
+        for i in range(audio_count)
+    ]
+    audio_chunks = [
+        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
+    ]
+
+    messages = [UserMessage(content=[*audio_chunks, text_chunk])]
+
+    req = ChatCompletionRequest(messages=messages, model=model_name)
+
+    tokens = tokenizer.encode_chat_completion(req)
+    prompt_ids, audios = tokens.tokens, tokens.audios
+
+    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
+
+    multi_modal_data = {"audio": audios_and_sr}
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt_token_ids=prompt_ids,
+        multi_modal_data=multi_modal_data,
+    )
+
+
 # Granite Speech
 def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
     # NOTE - the settings in this example are somewhat different than what is
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


 model_example_map = {
+    "voxtral": run_voxtral,
     "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
@@ -311,16 +368,24 @@ def main(args):
         temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
     )

-    mm_data = {}
-    if audio_count > 0:
-        mm_data = {
-            "audio": [
-                asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
-            ]
-        }
+    mm_data = req_data.multi_modal_data
+    if not mm_data:
+        mm_data = {}
+        if audio_count > 0:
+            mm_data = {
+                "audio": [
+                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
+                ]
+            }

     assert args.num_prompts > 0
-    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
+    inputs = {"multi_modal_data": mm_data}
+
+    if req_data.prompt:
+        inputs["prompt"] = req_data.prompt
+    else:
+        inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
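For readers following the new path end to end, here is a hedged sketch of how main() ultimately drives vLLM with the Voxtral request, assuming it runs inside this example module (so run_voxtral and audio_assets are in scope); the question string and single-clip count are illustrative:

from dataclasses import asdict

from vllm import LLM, SamplingParams

# Build the request the same way the dispatcher would (one clip, sample question).
req_data = run_voxtral("What can you tell me about this audio clip?", audio_count=1)

# The engine is constructed from the returned EngineArgs.
llm = LLM(**asdict(req_data.engine_args))

sampling_params = SamplingParams(
    temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)

# Voxtral hands back pre-tokenized ids, so the prompt is passed as token ids and
# the decoded (audio_array, sampling_rate) tuples ride along in multi_modal_data.
inputs = {
    "prompt_token_ids": req_data.prompt_token_ids,
    "multi_modal_data": req_data.multi_modal_data,
}

outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
    print(o.outputs[0].text)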