@@ -976,10 +976,13 @@ def _process_image_input(
976976 image_embeds = self .visual (pixel_values , grid_thw = grid_thw_list )
977977
978978 # Split concatenated embeddings for each image item.
979+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
979980 merge_size = self .visual .spatial_merge_size
980- sizes = grid_thw .prod (- 1 ) // merge_size // merge_size
981+ sizes = (
982+ torch .prod (torch .tensor (grid_thw_list , dtype = torch .long ), - 1 ) //
983+ merge_size // merge_size ).tolist ()
981984
982- return image_embeds .split (sizes . tolist () )
985+ return image_embeds .split (sizes )
983986
984987 def _process_video_input (
985988 self ,
@@ -998,9 +1001,12 @@ def _process_video_input(
9981001
9991002 # Split concatenated embeddings for each video item.
10001003 merge_size = self .visual .spatial_merge_size
1001- sizes = grid_thw .prod (- 1 ) // merge_size // merge_size
1004+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
1005+ sizes = (
1006+ torch .prod (torch .tensor (grid_thw_list , dtype = torch .long ), - 1 ) //
1007+ merge_size // merge_size ).tolist ()
10021008
1003- return video_embeds .split (sizes . tolist () )
1009+ return video_embeds .split (sizes )
10041010
10051011 def _parse_and_validate_multimodal_inputs (self , ** kwargs : object ) -> dict :
10061012 mm_input_by_modality = {}
0 commit comments