Skip to content

Commit 5b8e941

Browse files
committed
Add mm max num images arg to prevent error
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
1 parent 66eac49 commit 5b8e941

File tree

1 file changed

+45
-10
lines changed

1 file changed

+45
-10
lines changed

vllm/benchmarks/datasets.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ class RandomMultiModalDataset(RandomDataset):
507507
DEFAULT_HEIGHT = 224
508508
DEFAULT_WIDTH = 224
509509
DEFAULT_NUM_IMAGES = 1
510+
DEFAULT_LIMIT_IMAGES_PER_PROMPT = 255
510511
DEFAULT_NUM_IMAGES_RANGE_RATIO = 0.0
511512
DEFAULT_DIMENSION_RANGE_RATIO = 0.0
512513
DEFAULT_ENABLE_MULTIMODAL_CHAT = False
@@ -528,6 +529,7 @@ def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
528529
def get_image_sampling_params(
529530
self,
530531
num_images_range_ratio: float,
532+
limit_images_per_prompt: int,
531533
dimension_range_ratio: float,
532534
width: int,
533535
height: int,
@@ -541,18 +543,33 @@ def get_image_sampling_params(
541543
"num_images_range_ratio must be < 1.0 to ensure a valid sampling "
542544
"range"
543545
)
544-
max_num_images = int(num_images * (1 + num_images_range_ratio))
546+
max_num_images = min(int(num_images * (1 + num_images_range_ratio)),
547+
limit_images_per_prompt)
545548
# ensure min num images is zero
546549
min_num_images = max(int(num_images * (1 - num_images_range_ratio)), 0)
550+
# assert min_num_images <= max_num_images
551+
assert min_num_images <= max_num_images, (
552+
"min_num_images must be <= max_num_images"
553+
)
547554
# Enforce dimension_range_ratio < 1
548555
assert dimension_range_ratio < 1.0, (
549556
"dimension_range_ratio must be < 1.0 to ensure a valid sampling "
550557
"range"
551558
)
552-
min_width = int(width * (1 - dimension_range_ratio))
559+
# Ensure min_width and min_height are at least 1 to prevent
560+
# sampling 0-sized images.
561+
min_width = max(1, int(width * (1 - dimension_range_ratio)))
553562
max_width = int(width * (1 + dimension_range_ratio))
554-
min_height = int(height * (1 - dimension_range_ratio))
563+
min_height = max(1, int(height * (1 - dimension_range_ratio)))
555564
max_height = int(height * (1 + dimension_range_ratio))
565+
566+
logger.info(
567+
"Sampling number of images from [%s, %s] and image dimensions from "
568+
"[%s, %s]x[%s, %s]",
569+
min_num_images, max_num_images, min_width, max_width, min_height,
570+
max_height,
571+
)
572+
556573
return (
557574
min_num_images,
558575
max_num_images,
@@ -593,6 +610,7 @@ def sample(
593610
input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
594611
output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
595612
num_images: int = DEFAULT_NUM_IMAGES,
613+
limit_images_per_prompt: int = DEFAULT_LIMIT_IMAGES_PER_PROMPT,
596614
num_images_range_ratio: float = DEFAULT_NUM_IMAGES_RANGE_RATIO,
597615
width: int = DEFAULT_WIDTH,
598616
height: int = DEFAULT_HEIGHT,
@@ -614,6 +632,7 @@ def sample(
614632
height: Image height in pixels
615633
num_images_range_ratio: Relative half-width of the sampling
616634
interval for number of images.
635+
limit_images_per_prompt: Maximum number of images per request
617636
dimension_range_ratio: Relative half-width of the sampling
618637
interval for image dimensions.
619638
enable_multimodal_chat: Whether to apply multimodal chat
@@ -641,6 +660,7 @@ def sample(
641660
max_height,
642661
) = self.get_image_sampling_params(
643662
num_images_range_ratio,
663+
limit_images_per_prompt,
644664
dimension_range_ratio,
645665
width,
646666
height,
@@ -680,19 +700,27 @@ def sample(
680700
)
681701
for width, height in image_dimensions_iterator
682702
])
683-
# Avoid changing the type of `prompt` from str to list[dict]
684-
request_prompt: Any = prompt
703+
685704
if enable_multimodal_chat:
686-
request_prompt = self.apply_multimodal_chat_transformation(
705+
# NOTE: This option is only provided for completeness given
706+
# that the serve.py benchmark currently does not use it.
707+
mm_chat_prompt: Any = prompt
708+
mm_chat_prompt = self.apply_multimodal_chat_transformation(
687709
prompt, mm_content)
688-
mm_requests.append(
689-
SampleRequest(
690-
prompt=request_prompt,
710+
sample_request = SampleRequest(
711+
prompt=mm_chat_prompt,
712+
prompt_len=total_input_len,
713+
expected_output_len=int(output_lens[i]),
714+
multi_modal_data=None,
715+
)
716+
else:
717+
sample_request = SampleRequest(
718+
prompt=prompt,
691719
prompt_len=total_input_len,
692720
expected_output_len=int(output_lens[i]),
693721
multi_modal_data=mm_content,
694722
)
695-
)
723+
mm_requests.append(sample_request)
696724
return mm_requests
697725

698726
# -----------------------------------------------------------------------------
@@ -903,6 +931,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
903931
default=RandomMultiModalDataset.DEFAULT_NUM_IMAGES,
904932
help="Number of images per request for random-mm dataset.",
905933
)
934+
random_mm_group.add_argument(
935+
"--random-mm-limit-images-per-request",
936+
type=int,
937+
default=RandomMultiModalDataset.DEFAULT_LIMIT_IMAGES_PER_PROMPT,
938+
help="Maximum number of images per request for random-mm dataset.",
939+
)
906940
random_mm_group.add_argument(
907941
"--random-mm-width",
908942
type=int,
@@ -1119,6 +1153,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
11191153
input_len=args.random_input_len,
11201154
output_len=args.random_output_len,
11211155
num_images=args.random_mm_images_per_request,
1156+
limit_images_per_prompt=args.random_mm_limit_images_per_request,
11221157
width=args.random_mm_width,
11231158
height=args.random_mm_height,
11241159
num_images_range_ratio=args.random_mm_images_per_request_range_ratio,

0 commit comments

Comments
 (0)