@@ -976,10 +976,12 @@ def _process_image_input(
976976 image_embeds = self .visual (pixel_values , grid_thw = grid_thw_list )
977977
978978 # Split concatenated embeddings for each image item.
979+ # Compute sizes from grid_thw_list, not grid_thw, to avoid a CUDA sync
979980 merge_size = self .visual .spatial_merge_size
980- sizes = grid_thw .prod (- 1 ) // merge_size // merge_size
981+ sizes = (torch .tensor (grid_thw_list , dtype = torch .long ).prod (- 1 ) //
982+ (merge_size * merge_size )).tolist ()
981983
982- return image_embeds .split (sizes . tolist () )
984+ return image_embeds .split (sizes )
983985
984986 def _process_video_input (
985987 self ,
@@ -998,9 +1000,11 @@ def _process_video_input(
9981000
9991001 # Split concatenated embeddings for each video item.
10001002 merge_size = self .visual .spatial_merge_size
1001- sizes = grid_thw .prod (- 1 ) // merge_size // merge_size
1003+ # Compute sizes from grid_thw_list, not grid_thw, to avoid a CUDA sync
1004+ sizes = (torch .tensor (grid_thw_list , dtype = torch .long ).prod (- 1 ) //
1005+ (merge_size * merge_size )).tolist ()
10021006
1003- return video_embeds .split (sizes . tolist () )
1007+ return video_embeds .split (sizes )
10041008
10051009 def _parse_and_validate_multimodal_inputs (self , ** kwargs : object ) -> dict :
10061010 mm_input_by_modality = {}
0 commit comments