Skip to content

Commit eebcfdb

Browse files
chaunceyjiang authored and sumitd2 committed
[Frontend] Multi-Modality Support for Loading Local Image Files (vllm-project#9915)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent a7d6697 commit eebcfdb

File tree

6 files changed

+132
-14
lines changed

6 files changed

+132
-14
lines changed

tests/multimodal/test_utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import base64
22
import mimetypes
3-
from tempfile import NamedTemporaryFile
3+
import os
4+
from tempfile import NamedTemporaryFile, TemporaryDirectory
45
from typing import Dict, Tuple
56

67
import numpy as np
78
import pytest
8-
from PIL import Image
9+
from PIL import Image, ImageChops
910
from transformers import AutoConfig, AutoTokenizer
1011

1112
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
@@ -84,6 +85,40 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
8485
assert _image_equals(data_image_sync, data_image_async)
8586

8687

88+
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
    """Check ``file://`` loading through both the sync and async fetchers.

    The remote image is saved into a temporary directory, read back through
    a ``file://`` URL with that directory allowed, and the two results are
    compared. URLs that climb out of the directory with ``..`` must raise
    ``ValueError`` whether or not an allowed directory is supplied.
    """
    file_name = os.path.basename(image_url)

    with TemporaryDirectory() as temp_dir:
        # Materialise a local copy of the remote test image.
        origin_image = fetch_image(image_url)
        origin_image.save(os.path.join(temp_dir, file_name),
                          quality=100,
                          icc_profile=origin_image.info.get('icc_profile'))

        inside_url = f"file://{temp_dir}/{file_name}"
        escaping_url = f"file://{temp_dir}/../{file_name}"

        image_async = await async_fetch_image(
            inside_url, allowed_local_media_path=temp_dir)
        image_sync = fetch_image(inside_url,
                                 allowed_local_media_path=temp_dir)
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        # Async fetcher: escaping path, with and without an allowed dir.
        with pytest.raises(ValueError):
            await async_fetch_image(escaping_url,
                                    allowed_local_media_path=temp_dir)
        with pytest.raises(ValueError):
            await async_fetch_image(escaping_url)

        # Sync fetcher: same two cases.
        with pytest.raises(ValueError):
            fetch_image(escaping_url, allowed_local_media_path=temp_dir)
        with pytest.raises(ValueError):
            fetch_image(escaping_url)
120+
121+
87122
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
88123
def test_repeat_and_pad_placeholder_tokens(model):
89124
config = AutoConfig.from_pretrained(model)

vllm/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class ModelConfig:
5555
"mistral" will always use the tokenizer from `mistral_common`.
5656
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
5757
downloading the model and tokenizer.
58+
allowed_local_media_path: Allowing API requests to read local images or
59+
videos from directories specified by the server file system.
60+
This is a security risk. Should only be enabled in trusted
61+
environments.
5862
dtype: Data type for model weights and activations. The "auto" option
5963
will use FP16 precision for FP32 and FP16 models, and BF16 precision
6064
for BF16 models.
@@ -134,6 +138,7 @@ def __init__(
134138
trust_remote_code: bool,
135139
dtype: Union[str, torch.dtype],
136140
seed: int,
141+
allowed_local_media_path: str = "",
137142
revision: Optional[str] = None,
138143
code_revision: Optional[str] = None,
139144
rope_scaling: Optional[dict] = None,
@@ -164,6 +169,7 @@ def __init__(
164169
self.tokenizer = tokenizer
165170
self.tokenizer_mode = tokenizer_mode
166171
self.trust_remote_code = trust_remote_code
172+
self.allowed_local_media_path = allowed_local_media_path
167173
self.seed = seed
168174
self.revision = revision
169175
self.code_revision = code_revision
@@ -1319,6 +1325,8 @@ def maybe_create_spec_config(
13191325
tokenizer=target_model_config.tokenizer,
13201326
tokenizer_mode=target_model_config.tokenizer_mode,
13211327
trust_remote_code=target_model_config.trust_remote_code,
1328+
allowed_local_media_path=target_model_config.
1329+
allowed_local_media_path,
13221330
dtype=target_model_config.dtype,
13231331
seed=target_model_config.seed,
13241332
revision=draft_revision,

vllm/engine/arg_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class EngineArgs:
9292
tokenizer_mode: str = 'auto'
9393
chat_template_text_format: str = 'string'
9494
trust_remote_code: bool = False
95+
allowed_local_media_path: str = ""
9596
download_dir: Optional[str] = None
9697
load_format: str = 'auto'
9798
config_format: ConfigFormat = ConfigFormat.AUTO
@@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
269270
parser.add_argument('--trust-remote-code',
270271
action='store_true',
271272
help='Trust remote code from huggingface.')
273+
parser.add_argument(
    '--allowed-local-media-path',
    type=str,
    # Adjacent string literals are joined verbatim, so every fragment
    # carries its own trailing space to keep the rendered help readable.
    help="Allowing API requests to read local images or videos "
    "from directories specified by the server file system. "
    "This is a security risk. "
    "Should only be enabled in trusted environments.")
272280
parser.add_argument('--download-dir',
273281
type=nullable_str,
274282
default=EngineArgs.download_dir,
@@ -920,6 +928,7 @@ def create_model_config(self) -> ModelConfig:
920928
tokenizer_mode=self.tokenizer_mode,
921929
chat_template_text_format=self.chat_template_text_format,
922930
trust_remote_code=self.trust_remote_code,
931+
allowed_local_media_path=self.allowed_local_media_path,
923932
dtype=self.dtype,
924933
seed=self.seed,
925934
revision=self.revision,

vllm/entrypoints/chat_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
307307
self._tracker = tracker
308308

309309
def parse_image(self, image_url: str) -> None:
310-
image = get_and_parse_image(image_url)
310+
image = get_and_parse_image(image_url,
311+
allowed_local_media_path=self._tracker.
312+
_model_config.allowed_local_media_path)
311313

312314
placeholder = self._tracker.add("image", image)
313315
self._add_placeholder(placeholder)
@@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
327329
self._tracker = tracker
328330

329331
def parse_image(self, image_url: str) -> None:
330-
image_coro = async_get_and_parse_image(image_url)
332+
image_coro = async_get_and_parse_image(
333+
image_url,
334+
allowed_local_media_path=self._tracker._model_config.
335+
allowed_local_media_path)
331336

332337
placeholder = self._tracker.add("image", image_coro)
333338
self._add_placeholder(placeholder)

vllm/entrypoints/llm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class LLM:
5858
from the input.
5959
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
6060
downloading the model and tokenizer.
61+
allowed_local_media_path: Allowing API requests to read local images
62+
or videos from directories specified by the server file system.
63+
This is a security risk. Should only be enabled in trusted
64+
environments.
6165
tensor_parallel_size: The number of GPUs to use for distributed
6266
execution with tensor parallelism.
6367
dtype: The data type for the model weights and activations. Currently,
@@ -139,6 +143,7 @@ def __init__(
139143
tokenizer_mode: str = "auto",
140144
skip_tokenizer_init: bool = False,
141145
trust_remote_code: bool = False,
146+
allowed_local_media_path: str = "",
142147
tensor_parallel_size: int = 1,
143148
dtype: str = "auto",
144149
quantization: Optional[str] = None,
@@ -179,6 +184,7 @@ def __init__(
179184
tokenizer_mode=tokenizer_mode,
180185
skip_tokenizer_init=skip_tokenizer_init,
181186
trust_remote_code=trust_remote_code,
187+
allowed_local_media_path=allowed_local_media_path,
182188
tensor_parallel_size=tensor_parallel_size,
183189
dtype=dtype,
184190
quantization=quantization,

vllm/multimodal/utils.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import os
23
from functools import lru_cache
34
from io import BytesIO
45
from typing import Any, List, Optional, Tuple, TypeVar, Union
@@ -18,19 +19,60 @@
1819
cached_get_tokenizer = lru_cache(get_tokenizer)
1920

2021

21-
def _load_image_from_bytes(b: bytes):
22+
def _load_image_from_bytes(b: bytes) -> Image.Image:
    """Decode raw image bytes into a loaded PIL image."""
    buffer = BytesIO(b)
    decoded = Image.open(buffer)
    # Image.open is lazy; load() forces the pixel data to be read now.
    decoded.load()
    return decoded
2526

2627

27-
def _load_image_from_data_url(image_url: str):
28+
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
    """Return True if ``image_path`` lies inside ``allowed_local_media_path``.

    Both paths are canonicalised with ``os.path.realpath`` before the
    comparison, so ``..`` components and symbolic links are resolved first.
    A symlink placed inside the allowed directory that points outside it is
    therefore rejected.

    Args:
        image_path: Path of the file the caller wants to read.
        allowed_local_media_path: Directory that reads are restricted to.

    Returns:
        True when the resolved ``image_path`` equals, or is nested under,
        the resolved ``allowed_local_media_path``; False otherwise.
    """
    allowed_root = os.path.realpath(allowed_local_media_path)
    candidate = os.path.realpath(image_path)
    try:
        common_path = os.path.commonpath([candidate, allowed_root])
    except ValueError:
        # commonpath raises when the paths cannot share an ancestor
        # (e.g. different drives on Windows): that is simply "not inside".
        return False
    # The candidate is inside the root iff their common ancestor is the root.
    return common_path == allowed_root
36+
37+
38+
def _load_image_from_file(image_url: str,
                          allowed_local_media_path: str) -> Image.Image:
    """Load an image referenced by a ``file://`` URL from local disk.

    Args:
        image_url: URL of the form ``file://<path>``.
        allowed_local_media_path: Directory that local files may be read
            from. An empty value disables local file loading entirely.

    Returns:
        The PIL image, with its data already loaded.

    Raises:
        ValueError: If ``allowed_local_media_path`` is empty, does not
            exist, or is not a directory, or if the requested path is not
            located under it.
    """
    if not allowed_local_media_path:
        raise ValueError("Invalid 'image_url': Cannot load local files "
                         "without '--allowed-local-media-path'.")
    if not os.path.exists(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            f"The path {allowed_local_media_path} does not exist.")
    if not os.path.isdir(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            f"The path {allowed_local_media_path} must be a directory.")

    # Only split once and assume the second part is the image path
    _, image_path = image_url.split("file://", 1)
    if not _is_subpath(image_path, allowed_local_media_path):
        raise ValueError(
            f"Invalid 'image_url': The file path {image_path} must"
            " be a subpath of '--allowed-local-media-path'"
            f" '{allowed_local_media_path}'.")

    image = Image.open(image_path)
    # Read the pixel data now instead of lazily on first access.
    image.load()
    return image
64+
65+
66+
def _load_image_from_data_url(image_url: str) -> Image.Image:
    """Decode the base64 payload of a ``data:`` URL into a PIL image."""
    # Everything after the first comma is treated as the base64 payload;
    # the part before it is discarded.
    _header, encoded_payload = image_url.split(",", 1)
    return load_image_from_base64(encoded_payload)
3170

3271

33-
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
72+
def fetch_image(image_url: str,
73+
*,
74+
image_mode: str = "RGB",
75+
allowed_local_media_path: str = "") -> Image.Image:
3476
"""
3577
Load a PIL image from a HTTP or base64 data URL.
3678
@@ -43,16 +85,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
4385

4486
elif image_url.startswith('data:image'):
4587
image = _load_image_from_data_url(image_url)
88+
elif image_url.startswith('file://'):
89+
image = _load_image_from_file(image_url, allowed_local_media_path)
4690
else:
4791
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
48-
"with either 'data:image' or 'http'.")
92+
"with either 'data:image', 'file://' or 'http'.")
4993

5094
return image.convert(image_mode)
5195

5296

5397
async def async_fetch_image(image_url: str,
5498
*,
55-
image_mode: str = "RGB") -> Image.Image:
99+
image_mode: str = "RGB",
100+
allowed_local_media_path: str = "") -> Image.Image:
56101
"""
57102
Asynchronously load a PIL image from a HTTP or base64 data URL.
58103
@@ -65,9 +110,11 @@ async def async_fetch_image(image_url: str,
65110

66111
elif image_url.startswith('data:image'):
67112
image = _load_image_from_data_url(image_url)
113+
elif image_url.startswith('file://'):
114+
image = _load_image_from_file(image_url, allowed_local_media_path)
68115
else:
69116
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
70-
"with either 'data:image' or 'http'.")
117+
"with either 'data:image', 'file://' or 'http'.")
71118

72119
return image.convert(image_mode)
73120

@@ -126,8 +173,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
126173
return {"audio": (audio, sr)}
127174

128175

129-
def get_and_parse_image(image_url: str) -> MultiModalDataDict:
130-
image = fetch_image(image_url)
176+
def get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Fetch ``image_url`` and wrap the result under the ``"image"`` key.

    ``allowed_local_media_path`` is forwarded unchanged to ``fetch_image``.
    """
    return {
        "image":
        fetch_image(image_url,
                    allowed_local_media_path=allowed_local_media_path)
    }
132183

133184

@@ -136,8 +187,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
136187
return {"audio": (audio, sr)}
137188

138189

139-
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
140-
image = await async_fetch_image(image_url)
190+
async def async_get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Asynchronously fetch ``image_url`` and wrap it under ``"image"``.

    ``allowed_local_media_path`` is forwarded unchanged to
    ``async_fetch_image``.
    """
    fetched = await async_fetch_image(
        image_url, allowed_local_media_path=allowed_local_media_path)
    return {"image": fetched}
142197

143198

0 commit comments

Comments
 (0)