Skip to content

Commit 6c68c29

Browse files
authored
Fix grid size calculation in qwen2.5-vl vision encoder warmup (#1004)
Signed-off-by: Kewei Wang <keweiwang@google.com>
1 parent 8665c45 commit 6c68c29

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tpu_inference/models/jax/qwen2_5_vl.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1077,12 +1077,15 @@ def precompile_vision_encoder(
10771077
image_shapes = warmup_config.get("image_shapes")
10781078

10791079
vc = self.vllm_config.model_config.hf_config.vision_config
1080+
factor = vc.patch_size * vc.spatial_merge_size
10801081
for input_hw in image_shapes:
10811082
if not isinstance(input_hw, list) or len(input_hw) != 2:
10821083
logger.warning(f"Skipping invalid shape {input_hw}.")
10831084
continue
10841085
h_input, w_input = input_hw
1085-
t, h, w = 1, h_input // vc.patch_size, w_input // vc.patch_size
1086+
h_processed = round(h_input / factor) * factor
1087+
w_processed = round(w_input / factor) * factor
1088+
t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
10861089
grid_thw = (t, h, w)
10871090
num_patches = t * h * w
10881091
patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size

0 commit comments

Comments
 (0)