@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
5656    """The type of the content part.""" 
5757
5858
59+ class  ChatCompletionContentPartImageEmbedsParam (TypedDict , total = False ):
60+     image_embeds : Required [Union [str , dict [str , str ]]]
61+     """ 
62+     The image embeddings. It can be either: 
63+     - A single base64 string. 
64+     - A dictionary where each value is a base64 string. 
65+     """ 
66+     type : Required [Literal ["image_embeds" ]]
67+     """The type of the content part.""" 
68+ 
69+ 
5970class  VideoURL (TypedDict , total = False ):
6071    url : Required [str ]
6172    """ 
@@ -109,6 +120,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
109120    ChatCompletionContentPartInputAudioParam ,
110121    ChatCompletionContentPartVideoParam , ChatCompletionContentPartRefusalParam ,
111122    CustomChatCompletionContentSimpleImageParam ,
123+     ChatCompletionContentPartImageEmbedsParam ,
112124    CustomChatCompletionContentSimpleAudioParam ,
113125    CustomChatCompletionContentSimpleVideoParam , str ]
114126
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
350362    return  detected_format 
351363
352364
353- ModalityStr  =  Literal ["image" , "audio" , "video" ]
365+ ModalityStr  =  Literal ["image" , "audio" , "video" ,  "image_embeds" ]
354366_T  =  TypeVar ("_T" )
355367
356368
@@ -391,7 +403,7 @@ def _placeholder_str(self, modality: ModalityStr,
391403        hf_config  =  self ._model_config .hf_config 
392404        model_type  =  hf_config .model_type 
393405
394-         if  modality  ==   "image" :
406+         if  modality  in  [ "image" ,  "image_embeds" ] :
395407            if  model_type  ==  "phi3_v" :
396408                # Workaround since this token is not defined in the tokenizer 
397409                return  f"<|image_{ current_count }  |>" 
@@ -470,10 +482,27 @@ def create_parser(self) -> "BaseMultiModalContentParser":
470482class  MultiModalItemTracker (BaseMultiModalItemTracker [object ]):
471483
472484    def  all_mm_data (self ) ->  Optional [MultiModalDataDict ]:
473-         if  self ._items_by_modality :
474-             return  dict (self ._items_by_modality )
475- 
476-         return  None 
485+         if  not  self ._items_by_modality :
486+             return  None 
487+         mm_inputs  =  {}
488+         items_by_modality  =  dict (self ._items_by_modality )
489+         if  "image"  in  items_by_modality  and  "image_embeds"  in  items_by_modality :
490+             raise  ValueError (\
491+                 "Mixing raw image and embedding inputs is not allowed" )
492+ 
493+         if  "image_embeds"  in  items_by_modality :
494+             image_embeds_lst  =  items_by_modality ["image_embeds" ]
495+             if  len (image_embeds_lst ) >  1 :
496+                 raise  ValueError (\
497+                     "Only one message can have {'type': 'image_embeds'}" )
498+             mm_inputs ["image" ] =  image_embeds_lst [0 ]
499+         elif  "image"  in  items_by_modality :
500+             mm_inputs ["image" ] =  items_by_modality ["image" ] # A list of images 
501+         elif  "audio"  in  items_by_modality :
502+             mm_inputs ["audio" ] =  items_by_modality ["audio" ] # A list of audios 
503+         elif  "video"  in  items_by_modality :
504+             mm_inputs ["video" ] =  items_by_modality ["video" ] # A list of videos 
505+         return  mm_inputs 
477506
478507    def  create_parser (self ) ->  "BaseMultiModalContentParser" :
479508        return  MultiModalContentParser (self )
@@ -482,13 +511,31 @@ def create_parser(self) -> "BaseMultiModalContentParser":
482511class  AsyncMultiModalItemTracker (BaseMultiModalItemTracker [Awaitable [object ]]):
483512
484513    async  def  all_mm_data (self ) ->  Optional [MultiModalDataDict ]:
485-         if  self ._items_by_modality :
486-             return  {
514+         if  not  self ._items_by_modality :
515+             return  None 
516+         mm_inputs  =  {}
517+         items_by_modality  =  {
487518                modality : await  asyncio .gather (* items )
488519                for  modality , items  in  self ._items_by_modality .items ()
489520            }
490521
491-         return  None 
522+         if  "image"  in  items_by_modality  and  "image_embeds"  in  items_by_modality :
523+             raise  ValueError (
524+                 "Mixing raw image and embedding inputs is not allowed" )
525+ 
526+         if  "image_embeds"  in  items_by_modality :
527+             image_embeds_lst  =  items_by_modality ["image_embeds" ]
528+             if  len (image_embeds_lst ) >  1 :
529+                 raise  ValueError (
530+                     "Only one message can have {'type': 'image_embeds'}" )
531+             mm_inputs ["image" ] =  image_embeds_lst [0 ]
532+         elif  "image"  in  items_by_modality :
533+             mm_inputs ["image" ] =  items_by_modality ["image" ] # A list of images 
534+         elif  "audio"  in  items_by_modality :
535+             mm_inputs ["audio" ] =  items_by_modality ["audio" ] # A list of audios 
536+         elif  "video"  in  items_by_modality :
537+             mm_inputs ["video" ] =  items_by_modality ["video" ] # A list of videos 
538+         return  mm_inputs 
492539
493540    def  create_parser (self ) ->  "BaseMultiModalContentParser" :
494541        return  AsyncMultiModalContentParser (self )
@@ -513,6 +560,11 @@ def mm_placeholder_counts(self) -> dict[str, int]:
513560    def  parse_image (self , image_url : str ) ->  None :
514561        raise  NotImplementedError 
515562
563+     @abstractmethod  
564+     def  parse_image_embeds (self ,
565+                            image_embeds : Union [str , dict [str , str ]]) ->  None :
566+         raise  NotImplementedError 
567+ 
516568    @abstractmethod  
517569    def  parse_audio (self , audio_url : str ) ->  None :
518570        raise  NotImplementedError 
@@ -543,6 +595,21 @@ def parse_image(self, image_url: str) -> None:
543595        placeholder  =  self ._tracker .add ("image" , image )
544596        self ._add_placeholder (placeholder )
545597
598+     def  parse_image_embeds (self ,
599+                            image_embeds : Union [str , dict [str , str ]]) ->  None :
600+         if  isinstance (image_embeds , dict ):
601+             embeds  =  {
602+                 k : self ._connector .fetch_image_embedding (v )
603+                 for  k , v  in  image_embeds .items ()
604+             }
605+             placeholder  =  self ._tracker .add ("image_embeds" , embeds )
606+ 
607+         if  isinstance (image_embeds , str ):
608+             embedding  =  self ._connector .fetch_image_embedding (image_embeds )
609+             placeholder  =  self ._tracker .add ("image_embeds" , embedding )
610+ 
611+         self ._add_placeholder (placeholder )
612+ 
546613    def  parse_audio (self , audio_url : str ) ->  None :
547614        audio  =  self ._connector .fetch_audio (audio_url )
548615
@@ -579,6 +646,25 @@ def parse_image(self, image_url: str) -> None:
579646        placeholder  =  self ._tracker .add ("image" , image_coro )
580647        self ._add_placeholder (placeholder )
581648
649+     def  parse_image_embeds (self ,
650+                            image_embeds : Union [str , dict [str , str ]]) ->  None :
651+         future : asyncio .Future [Union [str , dict [str , str ]]] =  asyncio .Future ()
652+ 
653+         if  isinstance (image_embeds , dict ):
654+             embeds  =  {
655+                 k : self ._connector .fetch_image_embedding (v )
656+                 for  k , v  in  image_embeds .items ()
657+             }
658+             future .set_result (embeds )
659+ 
660+         if  isinstance (image_embeds , str ):
661+             embedding  =  self ._connector .\
662+                 fetch_image_embedding (image_embeds )
663+             future .set_result (embedding )
664+ 
665+         placeholder  =  self ._tracker .add ("image_embeds" , future )
666+         self ._add_placeholder (placeholder )
667+ 
582668    def  parse_audio (self , audio_url : str ) ->  None :
583669        audio_coro  =  self ._connector .fetch_audio_async (audio_url )
584670
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
684770# No need to validate using Pydantic again 
685771_TextParser  =  partial (cast , ChatCompletionContentPartTextParam )
686772_ImageParser  =  partial (cast , ChatCompletionContentPartImageParam )
773+ _ImageEmbedsParser  =  partial (cast , ChatCompletionContentPartImageEmbedsParam )
687774_AudioParser  =  partial (cast , ChatCompletionContentPartAudioParam )
688775_InputAudioParser  =  partial (cast , ChatCompletionContentPartInputAudioParam )
689776_RefusalParser  =  partial (cast , ChatCompletionContentPartRefusalParam )
@@ -700,6 +787,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
700787    lambda  part : _TextParser (part ).get ("text" , "" ),
701788    "image_url" :
702789    lambda  part : _ImageParser (part ).get ("image_url" , {}).get ("url" , "" ),
790+     "image_embeds" :
791+     lambda  part : _ImageEmbedsParser (part ).get ("image_embeds" , {}),
703792    "audio_url" :
704793    lambda  part : _AudioParser (part ).get ("audio_url" , {}).get ("url" , "" ),
705794    "input_audio" :
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
769858
770859
771860VALID_MESSAGE_CONTENT_MM_PART_TYPES  =  ("text" , "refusal" , "image_url" ,
861+                                        "image_embeds" ,
772862                                       "audio_url" , "input_audio" , "video_url" )
773863
774864
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
843933        str_content  =  cast (str , content )
844934        mm_parser .parse_image (str_content )
845935        return  {'type' : 'image' } if  wrap_dicts  else  None 
846- 
936+     if  part_type  ==  "image_embeds" :
937+         content  =  cast (Union [str , dict [str , str ]], content )
938+         mm_parser .parse_image_embeds (content )
939+         return  {'type' : 'image' } if  wrap_dicts  else  None 
847940    if  part_type  ==  "audio_url" :
848941        str_content  =  cast (str , content )
849942        mm_parser .parse_audio (str_content )
0 commit comments