@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
5656 """The type of the content part."""
5757
5858
59+ class ChatCompletionContentPartImageEmbedsParam (TypedDict , total = False ):
60+ image_embeds : Required [Union [str , dict [str , str ]]]
61+ """
62+ The image embeddings. It can be either:
63+ - A single base64 string.
64+ - A dictionary where each value is a base64 string.
65+ """
66+ type : Required [Literal ["image_embeds" ]]
67+ """The type of the content part."""
68+
69+
5970class VideoURL (TypedDict , total = False ):
6071 url : Required [str ]
6172 """
@@ -109,6 +120,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
109120 ChatCompletionContentPartInputAudioParam ,
110121 ChatCompletionContentPartVideoParam , ChatCompletionContentPartRefusalParam ,
111122 CustomChatCompletionContentSimpleImageParam ,
123+ ChatCompletionContentPartImageEmbedsParam ,
112124 CustomChatCompletionContentSimpleAudioParam ,
113125 CustomChatCompletionContentSimpleVideoParam , str ]
114126
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
350362 return detected_format
351363
352364
# Modality tags accepted wherever a ``ModalityStr`` is expected.
# "image_embeds" identifies image embeddings passed in as a content part.
ModalityStr = Literal[
    "image",
    "audio",
    "video",
    "image_embeds",
]

# Generic type variable.
_T = TypeVar("_T")
355367
356368
@@ -391,7 +403,7 @@ def _placeholder_str(self, modality: ModalityStr,
391403 hf_config = self ._model_config .hf_config
392404 model_type = hf_config .model_type
393405
394- if modality == "image" :
406+ if modality in [ "image" , "image_embeds" ] :
395407 if model_type == "phi3_v" :
396408 # Workaround since this token is not defined in the tokenizer
397409 return f"<|image_{ current_count } |>"
@@ -470,10 +482,27 @@ def create_parser(self) -> "BaseMultiModalContentParser":
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are plain objects, so no awaiting is needed."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Collect every tracked item into one multimodal input mapping.

        Returns:
            ``None`` when nothing has been tracked. Otherwise a dict that may
            contain the keys ``"image"``, ``"audio"`` and ``"video"``. Raw
            images, audios and videos are returned as lists; an
            ``"image_embeds"`` entry is returned unwrapped (its single item)
            under the ``"image"`` key.

        Raises:
            ValueError: If raw images and image embeddings are both present,
                or if more than one ``"image_embeds"`` item was tracked.
        """
        if not self._items_by_modality:
            return None

        items_by_modality = dict(self._items_by_modality)
        has_embeds = "image_embeds" in items_by_modality
        if has_embeds and "image" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}
        if has_embeds:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            # Embeddings are handed over under the "image" key, unwrapped
            # from the single-element list.
            mm_inputs["image"] = image_embeds_lst[0]

        # Each modality is checked independently. An if/elif chain would keep
        # only the first modality found and silently drop the others. "image"
        # cannot overwrite the embeddings above because mixing already raised.
        for modality in ("image", "audio", "video"):
            if modality in items_by_modality:
                mm_inputs[modality] = items_by_modality[modality]
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a ``MultiModalContentParser`` bound to this tracker."""
        return MultiModalContentParser(self)
@@ -482,13 +511,31 @@ def create_parser(self) -> "BaseMultiModalContentParser":
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables that are resolved on collection."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await every tracked item and build one multimodal input mapping.

        Returns:
            ``None`` when nothing has been tracked. Otherwise a dict that may
            contain the keys ``"image"``, ``"audio"`` and ``"video"``. Raw
            images, audios and videos are returned as lists of resolved
            results; an ``"image_embeds"`` entry is returned unwrapped (its
            single resolved item) under the ``"image"`` key.

        Raises:
            ValueError: If raw images and image embeddings are both present,
                or if more than one ``"image_embeds"`` item was tracked.
        """
        if not self._items_by_modality:
            return None

        # Items of one modality are awaited together; modalities are
        # processed one after another.
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }

        has_embeds = "image_embeds" in items_by_modality
        if has_embeds and "image" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}
        if has_embeds:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            # Embeddings are handed over under the "image" key, unwrapped
            # from the single-element list.
            mm_inputs["image"] = image_embeds_lst[0]

        # Each modality is checked independently. An if/elif chain would keep
        # only the first modality found and silently drop the others. "image"
        # cannot overwrite the embeddings above because mixing already raised.
        for modality in ("image", "audio", "video"):
            if modality in items_by_modality:
                mm_inputs[modality] = items_by_modality[modality]
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return an ``AsyncMultiModalContentParser`` bound to this tracker."""
        return AsyncMultiModalContentParser(self)
@@ -513,6 +560,11 @@ def mm_placeholder_counts(self) -> dict[str, int]:
513560 def parse_image (self , image_url : str ) -> None :
514561 raise NotImplementedError
515562
@abstractmethod
def parse_image_embeds(
        self, image_embeds: Union[str, dict[str, str]]) -> None:
    """Handle an ``image_embeds`` content part; subclasses must override.

    Args:
        image_embeds: Either a single base64 string or a dictionary whose
            values are base64 strings.
    """
    raise NotImplementedError
567+
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
    """Handle an audio content part given by URL; subclasses must override.

    Args:
        audio_url: URL string of the audio to parse.
    """
    raise NotImplementedError
@@ -543,6 +595,21 @@ def parse_image(self, image_url: str) -> None:
543595 placeholder = self ._tracker .add ("image" , image )
544596 self ._add_placeholder (placeholder )
545597
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Fetch image embeddings and register them with the tracker.

    Args:
        image_embeds: Either a single base64 string, or a dictionary whose
            values are base64 strings. Each string is passed to the
            connector's ``fetch_image_embedding``; for a dictionary the
            keys are preserved.

    Raises:
        ValueError: If ``image_embeds`` is neither a ``dict`` nor a ``str``.
            Previously this case surfaced as an ``UnboundLocalError``
            because no placeholder had been assigned.
    """
    fetch = self._connector.fetch_image_embedding

    if isinstance(image_embeds, dict):
        embeds = {key: fetch(value) for key, value in image_embeds.items()}
    elif isinstance(image_embeds, str):
        embeds = fetch(image_embeds)
    else:
        # ValueError matches the other validation errors in this module.
        raise ValueError(
            "image_embeds must be a base64 string or a dict of base64 "
            f"strings, got {type(image_embeds).__name__}")

    placeholder = self._tracker.add("image_embeds", embeds)
    self._add_placeholder(placeholder)
612+
546613 def parse_audio (self , audio_url : str ) -> None :
547614 audio = self ._connector .fetch_audio (audio_url )
548615
@@ -579,6 +646,25 @@ def parse_image(self, image_url: str) -> None:
579646 placeholder = self ._tracker .add ("image" , image_coro )
580647 self ._add_placeholder (placeholder )
581648
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Fetch image embeddings and register them as an awaitable item.

    The embeddings are fetched synchronously through the connector's
    ``fetch_image_embedding`` and wrapped in an already-resolved
    ``asyncio.Future``, so the tracker receives an awaitable.

    Args:
        image_embeds: Either a single base64 string, or a dictionary whose
            values are base64 strings; for a dictionary the keys are
            preserved.

    Raises:
        ValueError: If ``image_embeds`` is neither a ``dict`` nor a ``str``.
            Previously such input registered a Future that was never
            resolved, so awaiting the tracked items would never finish.
    """
    fetch = self._connector.fetch_image_embedding

    # Validate and fetch before creating the Future so an invalid input
    # never leaves an unresolved Future behind.
    if isinstance(image_embeds, dict):
        embeds = {key: fetch(value) for key, value in image_embeds.items()}
    elif isinstance(image_embeds, str):
        embeds = fetch(image_embeds)
    else:
        # ValueError matches the other validation errors in this module.
        raise ValueError(
            "image_embeds must be a base64 string or a dict of base64 "
            f"strings, got {type(image_embeds).__name__}")

    future: asyncio.Future[object] = asyncio.Future()
    future.set_result(embeds)

    placeholder = self._tracker.add("image_embeds", future)
    self._add_placeholder(placeholder)
667+
582668 def parse_audio (self , audio_url : str ) -> None :
583669 audio_coro = self ._connector .fetch_audio_async (audio_url )
584670
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
684770# No need to validate using Pydantic again
685771_TextParser = partial (cast , ChatCompletionContentPartTextParam )
686772_ImageParser = partial (cast , ChatCompletionContentPartImageParam )
773+ _ImageEmbedsParser = partial (cast , ChatCompletionContentPartImageEmbedsParam )
687774_AudioParser = partial (cast , ChatCompletionContentPartAudioParam )
688775_InputAudioParser = partial (cast , ChatCompletionContentPartInputAudioParam )
689776_RefusalParser = partial (cast , ChatCompletionContentPartRefusalParam )
@@ -700,6 +787,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
700787 lambda part : _TextParser (part ).get ("text" , "" ),
701788 "image_url" :
702789 lambda part : _ImageParser (part ).get ("image_url" , {}).get ("url" , "" ),
790+ "image_embeds" :
791+ lambda part : _ImageEmbedsParser (part ).get ("image_embeds" , {}),
703792 "audio_url" :
704793 lambda part : _AudioParser (part ).get ("audio_url" , {}).get ("url" , "" ),
705794 "input_audio" :
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
769858
770859
# Recognized ``type`` values for multimodal message content parts.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
    "text",
    "refusal",
    "image_url",
    "image_embeds",
    "audio_url",
    "input_audio",
    "video_url",
)
773863
774864
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
843933 str_content = cast (str , content )
844934 mm_parser .parse_image (str_content )
845935 return {'type' : 'image' } if wrap_dicts else None
846-
936+ if part_type == "image_embeds" :
937+ content = cast (Union [str , dict [str , str ]], content )
938+ mm_parser .parse_image_embeds (content )
939+ return {'type' : 'image' } if wrap_dicts else None
847940 if part_type == "audio_url" :
848941 str_content = cast (str , content )
849942 mm_parser .parse_audio (str_content )
0 commit comments