Skip to content

Commit ba214df

Browse files
[Bugfix] Fix precision error in LLaVA-NeXT (#11735)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent eed11eb commit ba214df

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

tests/models/decoder_only/vision_language/processing/test_llava_next.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ def processor_for_llava_next():
1515
return LlavaNextMultiModalProcessor
1616

1717

18-
# FIXME: image_size [(198, 176), (176, 198)]
1918
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
2019
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
21-
(488, 183)])
20+
(488, 183), (198, 176), (176, 198)])
2221
@pytest.mark.parametrize("num_imgs", [1, 2])
2322
def test_processor_prompt_replacements(
2423
processor_for_llava_next,

vllm/model_executor/models/llava_next.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
33
TypedDict, Union)
44

5+
import numpy as np
56
import torch
67
import torch.nn as nn
78
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
@@ -139,16 +140,21 @@ def _get_num_unpadded_features(
139140
current_height = npatches * num_patch_height
140141
current_width = npatches * num_patch_width
141142

142-
original_aspect_ratio = original_width / original_height
143-
current_aspect_ratio = current_width / current_height
143+
# NOTE: HF resizes based on float32
144+
original_aspect_ratio = np.array(original_width / original_height,
145+
dtype=np.float32)
146+
current_aspect_ratio = np.array(current_width / current_height,
147+
dtype=np.float32)
144148

145149
if original_aspect_ratio > current_aspect_ratio:
146-
scale_factor = current_width / original_width
150+
scale_factor = np.array(current_width / original_width,
151+
dtype=np.float32)
147152
new_height = int(original_height * scale_factor)
148153
padding = (current_height - new_height) // 2
149154
current_height -= 2 * padding
150155
else:
151-
scale_factor = current_height / original_height
156+
scale_factor = np.array(current_height / original_height,
157+
dtype=np.float32)
152158
new_width = int(original_width * scale_factor)
153159
padding = (current_width - new_width) // 2
154160
current_width -= 2 * padding

vllm/model_executor/models/llava_onevision.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
44
TypedDict, Union)
55

6+
import numpy as np
67
import torch
78
import torch.nn as nn
89
from transformers import (BatchFeature, LlavaOnevisionConfig,
@@ -127,18 +128,24 @@ def _get_num_unpadded_features(
127128
current_height = npatches * num_patch_height
128129
current_width = npatches * num_patch_width
129130

130-
original_aspect_ratio = original_width / original_height
131-
current_aspect_ratio = current_width / current_height
131+
# NOTE: HF resizes based on float32
132+
original_aspect_ratio = np.array(original_width / original_height,
133+
dtype=np.float32)
134+
current_aspect_ratio = np.array(current_width / current_height,
135+
dtype=np.float32)
136+
132137
if original_aspect_ratio > current_aspect_ratio:
133-
new_height = int(original_height *
134-
(current_width / original_width))
138+
scale_factor = np.array(current_width / original_width,
139+
dtype=np.float32)
140+
new_height = int(original_height * scale_factor)
135141
padding = (current_height - new_height) // 2
136-
current_height -= padding * 2
142+
current_height -= 2 * padding
137143
else:
138-
new_width = int(original_width *
139-
(current_height / original_height))
144+
scale_factor = np.array(current_height / original_height,
145+
dtype=np.float32)
146+
new_width = int(original_width * scale_factor)
140147
padding = (current_width - new_width) // 2
141-
current_width -= padding * 2
148+
current_width -= 2 * padding
142149

143150
unpadded_features = current_height * current_width
144151
newline_features = current_height

0 commit comments

Comments
 (0)