diff --git a/lmms_eval/models/gemini_api.py b/lmms_eval/models/gemini_api.py index 69f520c1..3fb1d39f 100644 --- a/lmms_eval/models/gemini_api.py +++ b/lmms_eval/models/gemini_api.py @@ -1,6 +1,8 @@ import io import json import os +import pathlib +import re import time from typing import List, Tuple @@ -39,8 +41,9 @@ def __init__( model_version: str = "gemini-1.5-pro", # modality: str = "image", timeout: int = 120, - continual_mode: bool = False, + continual_mode: bool = True, response_persistent_folder: str = "./logs/gemini_persistent_folder", + interleave: bool = False, # We will cache the Gemini API response in this path and use it for future requests **kwargs, ) -> None: @@ -49,6 +52,7 @@ def __init__( self.timeout = timeout self.model = genai.GenerativeModel(model_version) self.continual_mode = continual_mode + self.interleave = interleave # if self.continual_mode and response_persistent_folder is None: # raise ValueError("Continual mode requires a persistent path for the response. We will cache the Gemini API response in this path and use it for future requests. Please provide a valid path.") if self.continual_mode: @@ -132,6 +136,20 @@ def convert_modality(self, images): eval_logger.error(f"Error converting video: {str(e)}") return images + def construct_interleaved_input(self, content, media): + pattern = r"" + parts = re.split(pattern, content) + result = [] + for i, part in enumerate(parts): + if i % 2 == 0: + if part == "": + continue + result.append(part) + else: + result.append(media[int(part)]) + + return result + def generate_until(self, requests) -> List[str]: res = [] pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") @@ -163,7 +181,10 @@ def get_uuid(task, split, doc_id): visuals = self.flatten(visuals) visuals = self.convert_modality(visuals) - message = [contexts] + visuals + if self.interleave: + message = self.construct_interleaved_input(contexts, visuals) + else: + message = [contexts] + visuals for attempt in range(5): try: @@ -213,3 +234,72 @@ def generate_until_multi_round(self, requests) -> List[str]: def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: # TODO assert False, "Gemini API not support" + + def get_image_audio_text_interleaved_messsage(self, image_path, audio_path, question): + # image_path for list of image path + # audio_path for list of audio path + # question for question + + # fixed image token and no audio in text + for index in range(1, 1 + len(image_path)): + question = question.replace(f"[img{index}]", "") + for index in range(1, 1 + len(audio_path)): + question = question.replace(f"[audio{index}]", "