@@ -1429,6 +1429,7 @@ def _process_image_input(
             self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1443,13 +1444,15 @@ def _process_image_input(
             image_embeds = self.visual(pixel_values,
                                        grid_thw=grid_thw.tolist())
         merge_size = self.visual.spatial_merge_size
-        sizes = grid_thw.prod(-1) // merge_size // merge_size
-        return image_embeds.split(sizes.tolist())
+        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
+                 (merge_size * merge_size)).tolist()
+        return image_embeds.split(sizes)
 
     def _process_video_input(
             self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
+        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1466,8 +1469,9 @@ def _process_video_input(
                                        grid_thw=grid_thw.tolist())
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = grid_thw.prod(-1) // merge_size // merge_size
-        return video_embeds.split(sizes.tolist())
+        sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
+                 (merge_size * merge_size)).tolist()
+        return video_embeds.split(sizes)
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         mm_input_by_modality = {}
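Both hunks apply the same change: the per-item size arithmetic now starts from the Python list `grid_thw_list` rather than from the `grid_thw` tensor. A minimal standalone sketch of that arithmetic follows; the (t, h, w) grids, merge size, and hidden size below are made-up values for illustration, not values from the patch:

import torch

# Two items whose patch grids are (t, h, w) = (1, 4, 4) and (1, 8, 4); the
# merger fuses merge_size**2 neighboring patches into one output embedding.
grid_thw_list = [[1, 4, 4], [1, 8, 4]]  # assumed grids, for illustration
merge_size = 2                          # assumed spatial_merge_size

# Per-item embedding counts: prod(t, h, w) // merge_size**2.
sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
         (merge_size * merge_size)).tolist()
print(sizes)  # [4, 8]

# split() then recovers one tensor per item from the concatenated embeddings.
hidden_size = 16  # assumed
embeds = torch.randn(sum(sizes), hidden_size)
print([t.shape for t in embeds.split(sizes)])
# [torch.Size([4, 16]), torch.Size([8, 16])]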