| 
 | 1 | +from argparse import ArgumentParser  | 
 | 2 | +from vllm import LLM, EngineArgs, SamplingParams  | 
 | 3 | +from vllm.assets.image import ImageAsset  | 
 | 4 | +from vllm.assets.video import VideoAsset  | 
 | 5 | +from vllm.multimodal.image import convert_image_mode  | 
 | 6 | +from dataclasses import asdict  | 
 | 7 | +from typing import Union  | 
 | 8 | +from PIL import Image  | 
 | 9 | +from dataclasses import dataclass  | 
 | 10 | +import yaml  | 
 | 11 | +from vllm_gaudi.extension.logger import logger as init_logger  | 
 | 12 | + | 
 | 13 | +logger = init_logger()  | 
 | 14 | + | 
 | 15 | + | 
 | 16 | +@dataclass  | 
 | 17 | +class PROMPT_DATA:  | 
 | 18 | +    _questions = {  | 
 | 19 | +        "image": [  | 
 | 20 | +            "What is the most prominent object in this image?",  | 
 | 21 | +            "Describe the scene in the image.",  | 
 | 22 | +            "What is the weather like in the image?",  | 
 | 23 | +            "Write a short poem about this image."  | 
 | 24 | +        ],  | 
 | 25 | +        "video": [  | 
 | 26 | +            "Describe this video",  | 
 | 27 | +            "Which movie would you associate this video with?"  | 
 | 28 | +        ]  | 
 | 29 | +    }  | 
 | 30 | + | 
 | 31 | +    _data = {  | 
 | 32 | +        "image":  | 
 | 33 | +        lambda source: convert_image_mode(  | 
 | 34 | +            ImageAsset("cherry_blossom").pil_image  | 
 | 35 | +            if source == "default" else Image.open(source), "RGB"),  | 
 | 36 | +        "video":  | 
 | 37 | +        lambda source: VideoAsset(name="baby_reading"  | 
 | 38 | +                                  if source == "default" else source,  | 
 | 39 | +                                  num_frames=16).np_ndarrays  | 
 | 40 | +    }  | 
 | 41 | + | 
 | 42 | +    def __post_init__(self):  | 
 | 43 | +        self._questions = self._questions  | 
 | 44 | +        self._data = self._data  | 
 | 45 | + | 
 | 46 | +    def get_prompts(self,  | 
 | 47 | +                    modality: str = "image",  | 
 | 48 | +                    media_source: str = "default",  | 
 | 49 | +                    num_prompts: int = 1,  | 
 | 50 | +                    skip_vision_data=False):  | 
 | 51 | +        if modality == "image":  | 
 | 52 | +            pholder = "<|image_pad|>"  | 
 | 53 | +        elif modality == "video":  | 
 | 54 | +            pholder = "<|video_pad|>"  | 
 | 55 | +        else:  | 
 | 56 | +            raise ValueError(f"Unsupported modality: {modality}."  | 
 | 57 | +                             " Supported modality: [image, video]")  | 
 | 58 | +        questions = self._questions[modality]  | 
 | 59 | +        prompts = [  | 
 | 60 | +            ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"  | 
 | 61 | +             f"<|im_start|>user\n<|vision_start|>{pholder}<|vision_end|>"  | 
 | 62 | +             f"{question}<|im_end|>\n"  | 
 | 63 | +             "<|im_start|>assistant\n") for question in questions  | 
 | 64 | +        ]  | 
 | 65 | + | 
 | 66 | +        data = self._data[modality](media_source)  | 
 | 67 | +        inputs = [{  | 
 | 68 | +            "prompt": prompts[i % len(prompts)],  | 
 | 69 | +            "multi_modal_data": {  | 
 | 70 | +                modality: data  | 
 | 71 | +            },  | 
 | 72 | +        } if not skip_vision_data else {  | 
 | 73 | +            "prompt": questions[i % len(questions)],  | 
 | 74 | +        } for i in range(num_prompts)]  | 
 | 75 | + | 
 | 76 | +        return inputs  | 
 | 77 | + | 
 | 78 | + | 
 | 79 | +def run_model(model_name: str, inputs: Union[dict, list[dict]], modality: str,  | 
 | 80 | +              **extra_engine_args):  | 
 | 81 | +    # Default mm_processor_kwargs  | 
 | 82 | +    # mm_processor_kwargs={  | 
 | 83 | +    #    "min_pixels": 28 * 28,  | 
 | 84 | +    #    "max_pixels": 1280 * 28 * 28,  | 
 | 85 | +    #    "fps": 1,  | 
 | 86 | +    # }  | 
 | 87 | +    passed_mm_processor_kwargs = extra_engine_args.get("mm_processor_kwargs",  | 
 | 88 | +                                                       {})  | 
 | 89 | +    passed_mm_processor_kwargs.setdefault("min_pixels", 28 * 28)  | 
 | 90 | +    passed_mm_processor_kwargs.setdefault("max_pixels", 1280 * 28 * 28)  | 
 | 91 | +    passed_mm_processor_kwargs.setdefault("fps", 1)  | 
 | 92 | +    extra_engine_args.update(  | 
 | 93 | +        {"mm_processor_kwargs": passed_mm_processor_kwargs})  | 
 | 94 | + | 
 | 95 | +    extra_engine_args.setdefault("max_model_len", 32768)  | 
 | 96 | +    extra_engine_args.setdefault("max_num_seqs", 5)  | 
 | 97 | +    extra_engine_args.setdefault("limit_mm_per_prompt", {modality: 1})  | 
 | 98 | + | 
 | 99 | +    sampling_params = SamplingParams(  | 
 | 100 | +        temperature=0.0,  | 
 | 101 | +        max_tokens=64,  | 
 | 102 | +    )  | 
 | 103 | + | 
 | 104 | +    engine_args = EngineArgs(model=model_name, **extra_engine_args)  | 
 | 105 | + | 
 | 106 | +    engine_args = asdict(engine_args)  | 
 | 107 | +    llm = LLM(**engine_args)  | 
 | 108 | + | 
 | 109 | +    outputs = llm.generate(  | 
 | 110 | +        inputs,  | 
 | 111 | +        sampling_params=sampling_params,  | 
 | 112 | +        use_tqdm=False,  # Disable tqdm for CI tests  | 
 | 113 | +    )  | 
 | 114 | +    return outputs  | 
 | 115 | + | 
 | 116 | + | 
 | 117 | +def start_test(model_card_path: str):  | 
 | 118 | +    with open(model_card_path) as f:  | 
 | 119 | +        model_card = yaml.safe_load(f)  | 
 | 120 | + | 
 | 121 | +    model_name = model_card.get("model_name", "Qwen/Qwen2.5-VL-7B-Instruct")  | 
 | 122 | +    test_config = model_card.get("test_config", [])  | 
 | 123 | +    if not test_config:  | 
 | 124 | +        logger.warning("No test configurations found.")  | 
 | 125 | +        return  | 
 | 126 | + | 
 | 127 | +    for config in test_config:  | 
 | 128 | +        modality = "image"  # Ensure modality is always defined  | 
 | 129 | +        try:  | 
 | 130 | +            modality = config.get("modality", "image")  | 
 | 131 | +            extra_engine_args = config.get("extra_engine_args", {})  | 
 | 132 | +            input_data_config = config.get("input_data_config", {})  | 
 | 133 | +            num_prompts = input_data_config.get("num_prompts", 1)  | 
 | 134 | +            media_source = input_data_config.get("media_source", "default")  | 
 | 135 | + | 
 | 136 | +            logger.info(  | 
 | 137 | +                "================================================\n"  | 
 | 138 | +                "Running test with configs:\n"  | 
 | 139 | +                "modality: %(modality)s\n"  | 
 | 140 | +                "input_data_config: %(input_data_config)s\n"  | 
 | 141 | +                "extra_engine_args: %(extra_engine_args)s\n"  | 
 | 142 | +                "================================================",  | 
 | 143 | +                dict(modality=modality,  | 
 | 144 | +                     input_data_config=input_data_config,  | 
 | 145 | +                     extra_engine_args=extra_engine_args))  | 
 | 146 | + | 
 | 147 | +            data = PROMPT_DATA()  | 
 | 148 | +            inputs = data.get_prompts(modality=modality,  | 
 | 149 | +                                      media_source=media_source,  | 
 | 150 | +                                      num_prompts=num_prompts)  | 
 | 151 | + | 
 | 152 | +            logger.info(  | 
 | 153 | +                "*** Questions for modality %(modality)s: %(questions)s",  | 
 | 154 | +                dict(modality=modality, questions=data._questions[modality]))  | 
 | 155 | +            responses = run_model(model_name, inputs, modality,  | 
 | 156 | +                                  **extra_engine_args)  | 
 | 157 | +            for response in responses:  | 
 | 158 | +                print(f"{response.outputs[0].text}")  | 
 | 159 | +                print("=" * 80)  | 
 | 160 | +        except Exception as e:  | 
 | 161 | +            logger.error("Error during test with modality %(modality)s: %(e)s",  | 
 | 162 | +                         dict(modality=modality, e=e))  | 
 | 163 | + | 
 | 164 | +            raise  | 
 | 165 | + | 
 | 166 | + | 
 | 167 | +def main():  | 
 | 168 | +    parser = ArgumentParser()  | 
 | 169 | +    parser.add_argument("--model-card-path",  | 
 | 170 | +                        required=True,  | 
 | 171 | +                        help="Path to .yaml file describing model parameters")  | 
 | 172 | +    args = parser.parse_args()  | 
 | 173 | +    start_test(args.model_card_path)  | 
 | 174 | + | 
 | 175 | + | 
 | 176 | +if __name__ == "__main__":  | 
 | 177 | +    try:  | 
 | 178 | +        main()  | 
 | 179 | +    except Exception:  | 
 | 180 | +        import os  | 
 | 181 | +        import traceback  | 
 | 182 | +        print("An error occurred during generation:")  | 
 | 183 | +        traceback.print_exc()  | 
 | 184 | +        os._exit(1)  | 