@@ -186,6 +186,66 @@ def test_get_frames_by_pts(self):
186186 with pytest .raises (AssertionError ):
187187 assert_tensor_equal (frames [0 ], frames [- 1 ])
188188
189+ def test_pts_apis_against_index_ref (self ):
190+ # Non-regression test for https://github.com/pytorch/torchcodec/pull/287
191+ # Get all frames in the video, then query all frames with all time-based
192+ # APIs exactly where those frames are supposed to start. We assert that
193+ # we get the expected frame.
194+ decoder = create_from_file (str (NASA_VIDEO .path ))
195+ scan_all_streams_to_update_metadata (decoder )
196+ add_video_stream (decoder )
197+
198+ metadata = get_json_metadata (decoder )
199+ metadata_dict = json .loads (metadata )
200+ num_frames = metadata_dict ["numFrames" ]
201+ assert num_frames == 390
202+
203+ stream_index = 3
204+ _ , all_pts_seconds_ref , _ = zip (
205+ * [
206+ get_frame_at_index (
207+ decoder , stream_index = stream_index , frame_index = frame_index
208+ )
209+ for frame_index in range (num_frames )
210+ ]
211+ )
212+ all_pts_seconds_ref = torch .tensor (all_pts_seconds_ref )
213+
214+ assert len (all_pts_seconds_ref .unique () == len (all_pts_seconds_ref ))
215+
216+ _ , pts_seconds , _ = zip (
217+ * [get_frame_at_pts (decoder , seconds = pts ) for pts in all_pts_seconds_ref ]
218+ )
219+ pts_seconds = torch .tensor (pts_seconds )
220+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
221+
222+ _ , pts_seconds , _ = get_frames_by_pts_in_range (
223+ decoder ,
224+ stream_index = stream_index ,
225+ start_seconds = 0 ,
226+ stop_seconds = all_pts_seconds_ref [- 1 ] + 1e-4 ,
227+ )
228+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
229+
230+ _ , pts_seconds , _ = zip (
231+ * [
232+ get_frames_by_pts_in_range (
233+ decoder ,
234+ stream_index = stream_index ,
235+ start_seconds = pts ,
236+ stop_seconds = pts + 1e-4 ,
237+ )
238+ for pts in all_pts_seconds_ref
239+ ]
240+ )
241+ pts_seconds = torch .tensor (pts_seconds )
242+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
243+
244+ _ , pts_seconds , _ = get_frames_by_pts (
245+ decoder , stream_index = stream_index , timestamps = all_pts_seconds_ref .tolist ()
246+ )
247+ assert_tensor_equal (pts_seconds , all_pts_seconds_ref )
248+
189249 def test_get_frames_in_range (self ):
190250 decoder = create_from_file (str (NASA_VIDEO .path ))
191251 scan_all_streams_to_update_metadata (decoder )
0 commit comments