import os
from dataclasses import asdict
- from typing import NamedTuple, Optional
+ from typing import Any, NamedTuple, Optional

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
-     prompt: str
+     prompt: Optional[str] = None
+     prompt_token_ids: Optional[list[int]] = None
+     multi_modal_data: Optional[dict[str, Any]] = None
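+     # Exactly one of `prompt` or `prompt_token_ids` should be set; models
+     # that pre-tokenize their prompt also supply `multi_modal_data` directly.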
    stop_token_ids: Optional[list[int]] = None
    lora_requests: Optional[list[LoRARequest]] = None
@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.


+ # Voxtral
+ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
+     from mistral_common.audio import Audio
+     from mistral_common.protocol.instruct.messages import (
+         AudioChunk,
+         RawAudio,
+         TextChunk,
+         UserMessage,
+     )
+     from mistral_common.protocol.instruct.request import ChatCompletionRequest
+     from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+     model_name = "mistralai/Voxtral-Mini-3B-2507"
+     tokenizer = MistralTokenizer.from_hf_hub(model_name)
+
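+     # Voxtral ships its config and tokenizer in Mistral format, hence the
+     # mistral config/load/tokenizer modes below.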
+     engine_args = EngineArgs(
+         model=model_name,
+         max_model_len=8192,
+         max_num_seqs=2,
+         limit_mm_per_prompt={"audio": audio_count},
+         config_format="mistral",
+         load_format="mistral",
+         tokenizer_mode="mistral",
+         enforce_eager=True,
+         enable_chunked_prefill=False,
+     )
+
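+     # Build a single user message: the audio clips first, then the question.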
+     text_chunk = TextChunk(text=question)
+     audios = [
+         Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
+         for i in range(audio_count)
+     ]
+     audio_chunks = [
+         AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
+     ]
+
+     messages = [UserMessage(content=[*audio_chunks, text_chunk])]
+
+     req = ChatCompletionRequest(messages=messages, model=model_name)
+
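+     # Tokenizing the chat request yields both the prompt token IDs and the
+     # processed audio clips.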
+     tokens = tokenizer.encode_chat_completion(req)
+     prompt_ids, audios = tokens.tokens, tokens.audios
+
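+     # vLLM takes audio as (waveform, sampling_rate) tuples, the same shape as
+     # the `audio_and_sample_rate` assets used by the other models below.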
+     audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
+
+     multi_modal_data = {"audio": audios_and_sr}
+
+     return ModelRequestData(
+         engine_args=engine_args,
+         prompt_token_ids=prompt_ids,
+         multi_modal_data=multi_modal_data,
+     )
+
+
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    # NOTE - the settings in this example are somewhat different than what is
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


model_example_map = {
+     "voxtral": run_voxtral,
    "granite_speech": run_granite_speech,
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
@@ -311,16 +368,24 @@ def main(args):
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

-     mm_data = {}
-     if audio_count > 0:
-         mm_data = {
-             "audio": [
-                 asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
-             ]
-         }
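+     # Prefer multimodal data prepared by the model runner (e.g. Voxtral);
+     # otherwise fall back to the shared audio assets.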
+     mm_data = req_data.multi_modal_data
+     if not mm_data:
+         mm_data = {}
+         if audio_count > 0:
+             mm_data = {
+                 "audio": [
+                     asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
+                 ]
+             }

    assert args.num_prompts > 0
-     inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
+     inputs = {"multi_modal_data": mm_data}
+
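+     # Models that pre-tokenize their prompt pass token IDs instead of a
+     # prompt string.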
+     if req_data.prompt:
+         inputs["prompt"] = req_data.prompt
+     else:
+         inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts