Skip to content

Commit 1bdf928

Browse files
authored
Fix pts -> index conversion in get_frame_at_timestamps[_in_range] (#287)
1 parent b841eb3 commit 1bdf928

File tree

3 files changed

+66
-7
lines changed

3 files changed

+66
-7
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,12 +1125,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
11251125
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
11261126
});
11271127
int64_t frameIndex = it - stream.allFrames.begin();
1128-
// If the frame index is larger than the size of allFrames, that means we
1129-
// couldn't match the pts value to the pts value of a NEXT FRAME. And
1130-
// that means that this timestamp falls during the time between when the
1131-
// last frame is displayed, and the video ends. Hence, it should map to the
1132-
// index of the last frame.
1133-
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
11341128
frameIndices[i] = frameIndex;
11351129
}
11361130

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,12 @@ class VideoDecoder {
299299
private:
300300
struct FrameInfo {
301301
int64_t pts = 0;
302-
int64_t nextPts = 0;
302+
// The value of this default is important: the last frame's nextPts will be
303+
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
304+
// structs with *increasing* nextPts values. That's a necessary condition
305+
// for the binary searches on those values to work properly (as typically
306+
// done during pts -> index conversions.)
307+
int64_t nextPts = INT64_MAX;
303308
};
304309
struct FilterState {
305310
UniqueAVFilterGraph filterGraph;

test/decoders/test_video_decoder_ops.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,66 @@ def test_get_frames_by_pts(self):
186186
with pytest.raises(AssertionError):
187187
assert_tensor_equal(frames[0], frames[-1])
188188

189+
def test_pts_apis_against_index_ref(self):
190+
# Non-regression test for https://github.com/pytorch/torchcodec/pull/287
191+
# Get all frames in the video, then query all frames with all time-based
192+
# APIs exactly where those frames are supposed to start. We assert that
193+
# we get the expected frame.
194+
decoder = create_from_file(str(NASA_VIDEO.path))
195+
scan_all_streams_to_update_metadata(decoder)
196+
add_video_stream(decoder)
197+
198+
metadata = get_json_metadata(decoder)
199+
metadata_dict = json.loads(metadata)
200+
num_frames = metadata_dict["numFrames"]
201+
assert num_frames == 390
202+
203+
stream_index = 3
204+
_, all_pts_seconds_ref, _ = zip(
205+
*[
206+
get_frame_at_index(
207+
decoder, stream_index=stream_index, frame_index=frame_index
208+
)
209+
for frame_index in range(num_frames)
210+
]
211+
)
212+
all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref)
213+
214+
assert len(all_pts_seconds_ref.unique() == len(all_pts_seconds_ref))
215+
216+
_, pts_seconds, _ = zip(
217+
*[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref]
218+
)
219+
pts_seconds = torch.tensor(pts_seconds)
220+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
221+
222+
_, pts_seconds, _ = get_frames_by_pts_in_range(
223+
decoder,
224+
stream_index=stream_index,
225+
start_seconds=0,
226+
stop_seconds=all_pts_seconds_ref[-1] + 1e-4,
227+
)
228+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
229+
230+
_, pts_seconds, _ = zip(
231+
*[
232+
get_frames_by_pts_in_range(
233+
decoder,
234+
stream_index=stream_index,
235+
start_seconds=pts,
236+
stop_seconds=pts + 1e-4,
237+
)
238+
for pts in all_pts_seconds_ref
239+
]
240+
)
241+
pts_seconds = torch.tensor(pts_seconds)
242+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
243+
244+
_, pts_seconds, _ = get_frames_by_pts(
245+
decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist()
246+
)
247+
assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
248+
189249
def test_get_frames_in_range(self):
190250
decoder = create_from_file(str(NASA_VIDEO.path))
191251
scan_all_streams_to_update_metadata(decoder)

0 commit comments

Comments
 (0)