diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index e55b8bbfdeaac..8691a61343ab6 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -35,8 +35,12 @@ def _load_image_from_data_url(image_url: str): return load_image_from_base64(image_base64) -def fetch_image(image_url: str) -> Image.Image: - """Load PIL image from a url or base64 encoded openai GPT4V format""" +def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: + """ + Load a PIL image from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ if image_url.startswith('http'): _validate_remote_url(image_url, name="image_url") @@ -53,7 +57,7 @@ def fetch_image(image_url: str) -> Image.Image: raise ValueError("Invalid 'image_url': A valid 'image_url' must start " "with either 'data:image' or 'http'.") - return image + return image.convert(image_mode) class ImageFetchAiohttp: @@ -70,8 +74,17 @@ def get_aiohttp_client(cls) -> aiohttp.ClientSession: return cls.aiohttp_client @classmethod - async def fetch_image(cls, image_url: str) -> Image.Image: - """Load PIL image from a url or base64 encoded openai GPT4V format""" + async def fetch_image( + cls, + image_url: str, + *, + image_mode: str = "RGB", + ) -> Image.Image: + """ + Asynchronously load a PIL image from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ if image_url.startswith('http'): _validate_remote_url(image_url, name="image_url") @@ -91,7 +104,7 @@ async def fetch_image(cls, image_url: str) -> Image.Image: "Invalid 'image_url': A valid 'image_url' must start " "with either 'data:image' or 'http'.") - return image + return image.convert(image_mode) async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: @@ -99,12 +112,19 @@ async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: return {"image": image} -def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: - """Encode a pillow image to base64 format.""" +def encode_image_base64( + image: Image.Image, + *, + image_mode: str = "RGB", + format: str = "JPEG", +) -> str: + """ + Encode a pillow image to base64 format. + By default, the image is converted into RGB format before being encoded. + """ buffered = BytesIO() - if format == 'JPEG': - image = image.convert('RGB') + image = image.convert(image_mode) image.save(buffered, format) return base64.b64encode(buffered.getvalue()).decode('utf-8')