 import typing
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Optional, TypedDict, Union
+from typing import Any, Callable, Dict, List, Optional, TypedDict, Union

 import numpy as np
 import typing_extensions
@@ -386,14 +386,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
     return_assistant_tokens_mask: Optional[bool] = False


-class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
+class ChatTemplateLoadKwargs(TypedDict, total=False):
     """
-    Keyword arguments for processor chat templates.
+    Keyword arguments used to load multimodal data in processor chat templates.

-        tokenize (`bool`, *optional*, defaults to `False`):
-            Whether to tokenize the output or not.
-        return_dict (`bool`, defaults to `False`):
-            Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
         num_frames (`int`, *optional*):
             Number of frames to sample uniformly. If not passed, the whole video is loaded.
         video_load_backend (`str`, *optional*, defaults to `"pyav"`):
@@ -415,13 +411,26 @@ def sample_indices_fn(num_frames, fps, metadata, **kwargs):
                 return np.linspace(start_idx, end_idx, num_frames, dtype=int)
     """

-    tokenize: Optional[bool] = False
-    return_dict: Optional[bool] = False
     num_frames: Optional[int] = None
     video_load_backend: Optional[str] = "pyav"
     video_fps: Optional[int] = None
     sampling_rate: Optional[int] = 16_000
     sample_indices_fn: Optional[Callable] = None
+    load_audio_from_video: Optional[bool] = False
+
+
+class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False):
+    """
+    Keyword arguments for processor's `apply_chat_template`.
+
+        tokenize (`bool`, *optional*, defaults to `False`):
+            Whether to tokenize the output or not.
+        return_dict (`bool`, defaults to `False`):
+            Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+    """
+
+    tokenize: Optional[bool] = False
+    return_dict: Optional[bool] = False


 class AllKwargsForChatTemplate(
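
To make the intent of the split concrete, here is a minimal usage sketch (the checkpoint name, file path, and message contents are placeholders, not part of this change): `ChatTemplateLoadKwargs` controls how raw media is loaded, `TokenizerChatTemplateKwargs` controls templating, and `ProcessorChatTemplateKwargs` exposes both plus `tokenize`/`return_dict` through one `apply_chat_template` call.

from transformers import AutoProcessor

# Placeholder checkpoint; any processor with a multimodal chat template is routed the same way.
processor = AutoProcessor.from_pretrained("some-org/some-multimodal-model")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "path": "clip.mp4"},
            {"type": "text", "text": "Describe what happens in the video."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,  # TokenizerChatTemplateKwargs
    num_frames=8,                # ChatTemplateLoadKwargs
    video_load_backend="pyav",   # ChatTemplateLoadKwargs
    tokenize=True,               # ProcessorChatTemplateKwargs
    return_dict=True,
)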
@@ -1236,11 +1245,11 @@ def __call__(

     def _process_messages_for_chat_template(
         self,
-        conversation: list[list[dict[str, str]]],
-        batch_images: list[ImageInput],
-        batch_videos: list[VideoInput],
-        batch_video_metadata: list[list[dict[str, any]]],
-        **chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
+        conversation: List[List[Dict[str, str]]],
+        batch_images: List[ImageInput],
+        batch_videos: List[VideoInput],
+        batch_video_metadata: List[List[Dict[str, any]]],
+        **mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
     ):
         """
         Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
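
For reference, the `conversation` argument here is already batched: a list of conversations, each a list of message dicts. An illustrative (hypothetical) example of the expected shape:

conversation = [
    [  # one conversation in the batch
        {
            "role": "user",
            "content": [
                {"type": "video", "path": "clip.mp4"},
                {"type": "text", "text": "What happens in this clip?"},
            ],
        },
    ],
]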
@@ -1311,18 +1320,18 @@ def apply_chat_template(
             )

         # Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
-        # and for multimodal chat template
+        # and for multimodal data loading. Everything else will be used in `__call__`
         tokenizer_template_kwargs = {}
         for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
-            tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
-            value = kwargs.pop(tokenizer_key, tokenizer_value)
+            default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
+            value = kwargs.pop(tokenizer_key, default_value)
             tokenizer_template_kwargs[tokenizer_key] = value

-        chat_template_kwargs = {}
-        for key in ProcessorChatTemplateKwargs.__annotations__.keys():
-            processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
-            value = kwargs.pop(key, processor_value)
-            chat_template_kwargs[key] = value
+        mm_load_kwargs = {}
+        for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys():
+            default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None)
+            value = kwargs.pop(mm_load_key, default_value)
+            mm_load_kwargs[mm_load_key] = value

         if isinstance(conversation, (list, tuple)) and (
             isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
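
The routing loops above rely on TypedDict `__annotations__` plus class-level defaults. A self-contained sketch of the same idiom, with made-up kwarg groups rather than the real classes:

from typing import Optional, TypedDict


class TemplateKwargs(TypedDict, total=False):
    add_generation_prompt: Optional[bool]


class LoadKwargs(TypedDict, total=False):
    num_frames: Optional[int]
    video_fps: Optional[int]


def split_kwargs(kwargs):
    """Pop annotation-declared keys into their own dicts; whatever remains goes to `__call__`."""
    template_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in TemplateKwargs.__annotations__}
    load_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in LoadKwargs.__annotations__}
    return template_kwargs, load_kwargs, kwargs


tpl, load, rest = split_kwargs({"add_generation_prompt": True, "num_frames": 8, "padding": True})
# tpl == {"add_generation_prompt": True}, load == {"num_frames": 8}, rest == {"padding": True}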
@@ -1333,13 +1342,8 @@ def apply_chat_template(
             is_batched = False
             conversations = [conversation]

-        num_frames = chat_template_kwargs.get("num_frames")
-        video_fps = chat_template_kwargs.get("video_fps")
-        video_load_backend = chat_template_kwargs.get("video_load_backend")
-        tokenize = chat_template_kwargs.get("tokenize")
-        return_dict = chat_template_kwargs.get("return_dict")
-        sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
-        sampling_rate = chat_template_kwargs.pop("sampling_rate")
+        tokenize = kwargs.pop("tokenize", False)
+        return_dict = kwargs.pop("return_dict", False)

         if tokenize:
             batch_images, batch_videos = [], []
@@ -1369,31 +1373,37 @@ def apply_chat_template(
                         if key in vision_info and vision_info["type"] == "video"
                     ]

-                    # Audio models do not accept nested list of audios (yet!)
-                    for fname in audio_fnames:
-                        batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))
                     for fname in image_fnames:
                         images.append(load_image(fname))
-                    for fname in video_fnames:
-                        if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
-                            video = [np.array(load_image(image_fname)).T for image_fname in fname]
-                            # create a 4D video because `load_video` always returns a 4D array
-                            video = np.stack(video)
-                            metadata = None
-                            logger.warning(
-                                "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
-                                "If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
-                            )
-                        else:
-                            video, metadata = load_video(
-                                fname,
-                                num_frames=num_frames,
-                                fps=video_fps,
-                                backend=video_load_backend,
-                                sample_indices_fn=sample_indices_fn,
-                            )
-                        videos.append(video)
-                        video_metadata.append(metadata)
+
+                    # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
+                    if not mm_load_kwargs["load_audio_from_video"]:
+                        for fname in audio_fnames:
+                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+                    else:
+                        for fname in video_fnames:
+                            if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
+                                video = [np.array(load_image(image_fname)).T for image_fname in fname]
+                                # create a 4D video because `load_video` always returns a 4D array
+                                video = np.stack(video)
+                                metadata = None
+                                audios = None
+                                logger.warning(
+                                    "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
+                                    "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
+                                )
+                            else:
+                                video, metadata = load_video(
+                                    fname,
+                                    num_frames=mm_load_kwargs["num_frames"],
+                                    fps=mm_load_kwargs["video_fps"],
+                                    backend=mm_load_kwargs["video_load_backend"],
+                                    sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
+                                )
+                                audios = load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])
+                            batch_audios.append(audios)
+                            videos.append(video)
+                            video_metadata.append(metadata)

                 # Currently all processors can accept nested list of batches, but not flat list of visuals
                 # So we'll make a batched list of images and let the processor handle it
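
At the call site, the new branch would be toggled roughly like this, reusing the placeholder `processor` and `messages` from the earlier sketch (the checkpoint is assumed to accept an audio track alongside video):

# Default: audio is loaded from explicit {"type": "audio"} entries in the messages.
out = processor.apply_chat_template(messages, tokenize=True, return_dict=True)

# With the new flag: the audio track is decoded from the referenced video files instead,
# resampled to `sampling_rate`.
out = processor.apply_chat_template(
    messages,
    load_audio_from_video=True,
    sampling_rate=16_000,
    tokenize=True,
    return_dict=True,
)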
@@ -1409,7 +1419,7 @@ def apply_chat_template(
                 batch_images=batch_images,
                 batch_videos=batch_videos,
                 batch_video_metadata=batch_video_metadata,
-                **chat_template_kwargs,
+                **mm_load_kwargs,
             )

         prompt = self.tokenizer.apply_chat_template(
@@ -1438,7 +1448,7 @@ def apply_chat_template(
                 text=prompt,
                 images=batch_images if batch_images else None,
                 videos=batch_videos if batch_videos else None,
-                audios=batch_audios if batch_audios else None,
+                audio=batch_audios if batch_audios else None,
                 **kwargs,
             )
             if return_dict: