 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from packaging.version import Version
 from transformers import BatchFeature
+from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
 from transformers.models.glm4v.image_processing_glm4v import (
     Glm4vImageProcessor, smart_resize)
@@ -1001,28 +1003,32 @@ def _get_video_second_idx(self, metadata: dict[str, Any],
         max_frame_idx = meta_frames - 1
         duration = metadata.get("duration",
                                 round(max_frame_idx / video_fps) + 1)
-        if duration <= video_processor.max_duration:
-            n = int(math.floor(duration * video_processor.fps))
-            frame_indices = [
-                min(
-                    max_frame_idx,
-                    int(math.ceil(i * video_fps / video_processor.fps)),
-                ) for i in range(n)
-            ]
+        do_sample_frames = metadata["do_sample_frames"]
+        if not do_sample_frames:
+            frame_indices = metadata["frames_indices"]
         else:
-            num_samples = int(video_processor.max_duration *
-                              video_processor.fps)
-            if num_samples >= meta_frames:
-                frame_indices = list(range(meta_frames))
-            else:
-                target_seconds = np.linspace(0,
-                                             duration,
-                                             num_samples,
-                                             endpoint=True)
+            if duration <= video_processor.max_duration:
+                n = int(math.floor(duration * video_processor.fps))
                 frame_indices = [
-                    min(max_frame_idx, int(math.ceil(t * video_fps)))
-                    for t in target_seconds
+                    min(
+                        max_frame_idx,
+                        int(math.ceil(i * video_fps / video_processor.fps)),
+                    ) for i in range(n)
                 ]
+            else:
+                num_samples = int(video_processor.max_duration *
+                                  video_processor.fps)
+                if num_samples >= meta_frames:
+                    frame_indices = list(range(meta_frames))
+                else:
+                    target_seconds = np.linspace(0,
+                                                 duration,
+                                                 num_samples,
+                                                 endpoint=True)
+                    frame_indices = [
+                        min(max_frame_idx, int(math.ceil(t * video_fps)))
+                        for t in target_seconds
+                    ]
 
         seen, uniq = set(), []
         for idx in frame_indices:
@@ -1139,7 +1145,9 @@ def _get_dummy_videos(
                 "fps": 2.0,
                 "duration": num_frames / 2.0,
                 "total_num_frames": num_frames,
+                "frames_indices": [i for i in range(num_frames)],
                 "video_backend": "opencv",
+                "do_sample_frames": False,
             }
             video_item = (video.copy(), video_metadata)
             video_items.append(video_item)
@@ -1172,34 +1180,37 @@ def _call_hf_processor(
         for item in mm_data.pop("videos", []):
             video_array, metadata = item
 
-            if metadata["video_backend"] == "opencv_dynamic":
-                mm_kwargs["do_sample_frames"] = False
-
-            elif metadata["total_num_frames"] != len(video_array):
-                logger.warning(
-                    "Total frames in metadata "
-                    "(%s) does not match the length of "
-                    "video array %s. This can "
-                    "be because the video is resampled "
-                    "in advance. This may cause "
-                    "a divergence with HF implementation.",
-                    metadata["total_num_frames"],
-                    len(video_array),
-                )
-                metadata["total_num_frames"] = len(video_array)
+            # don't update mm_kwargs inplace
+            video_mm_kwargs = dict(**mm_kwargs)
+            video_mm_kwargs["do_sample_frames"] = metadata.get(
+                "do_sample_frames", True)
 
             video_mm_data = dict()
             video_mm_data["videos"] = [[video_array]]
-            video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]]
+
+            # backward compatibility for Transformers 4.55
+            unuse_metadata = ["do_sample_frames"]
+            if not hasattr(
+                    VideoMetadata,
+                    "frames_indices") and "frames_indices" in metadata:
+                unuse_metadata.append("frames_indices")
+
+            video_mm_data["video_metadata"] = [[
+                VideoMetadata(
+                    **{
+                        k: metadata[k]
+                        for k in metadata if k not in unuse_metadata
+                    })
+            ]]
 
             video_outputs = super()._call_hf_processor(
                 prompt="<|begin_of_video|><|video|><|end_of_video|>",
                 mm_data=video_mm_data,
-                mm_kwargs=mm_kwargs,
+                mm_kwargs=video_mm_kwargs,
                 tok_kwargs=tok_kwargs,
             )
-            if "do_sample_frames" in mm_kwargs and not mm_kwargs[
-                    "do_sample_frames"]:
+            if not video_mm_kwargs["do_sample_frames"] and Version(
+                    TRANSFORMERS_VERSION) < Version("4.56.0"):
                 # Transformers v4.55 has incorrect timestamps issue for
                 # skip sampling. We construct the placeholder manually to
                 # get placeholders with correct timestamps.
@@ -1218,6 +1229,7 @@ def _call_hf_processor(
                 prompt = prompt.replace(
                     "<|begin_of_video|><|video|><|end_of_video|>",
                     video_placeholder,
+                    1,
                 )
 
             video_grid_thw_lst.append(video_outputs["video_grid_thw"])
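
For context, a minimal standalone sketch (not part of the diff) of the new version gate: the manual-placeholder fallback is taken only when frame sampling is skipped and the installed Transformers release predates 4.56.0. The helper name below is hypothetical; only `packaging` and `transformers` are assumed to be installed.

```python
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION


def needs_manual_video_placeholder(do_sample_frames: bool) -> bool:
    # Transformers < 4.56.0 produces incorrect timestamps when frame sampling
    # is skipped, so the video placeholder must be constructed manually there.
    return (not do_sample_frames
            and Version(TRANSFORMERS_VERSION) < Version("4.56.0"))
```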