@@ -606,25 +606,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
606606}
607607
608608FrameBatchOutput SingleStreamDecoder::getFramesAtIndices (
609- const std::vector< int64_t > & frameIndices) {
609+ const torch::Tensor & frameIndices) {
610610 validateActiveStream (AVMEDIA_TYPE_VIDEO);
611611
612- auto indicesAreSorted =
613- std::is_sorted (frameIndices.begin (), frameIndices.end ());
612+ auto frameIndicesAccessor = frameIndices.accessor <int64_t , 1 >();
613+
614+ bool indicesAreSorted = true ;
615+ for (int64_t i = 1 ; i < frameIndices.numel (); ++i) {
616+ if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1 ]) {
617+ indicesAreSorted = false ;
618+ break ;
619+ }
620+ }
614621
615622 std::vector<size_t > argsort;
616623 if (!indicesAreSorted) {
617624 // if frameIndices is [13, 10, 12, 11]
618625 // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
619626 // to use to decode the frames
620627 // and argsort is [ 1, 3, 2, 0]
621- argsort.resize (frameIndices.size ());
628+ argsort.resize (frameIndices.numel ());
622629 for (size_t i = 0 ; i < argsort.size (); ++i) {
623630 argsort[i] = i;
624631 }
625632 std::sort (
626- argsort.begin (), argsort.end (), [&frameIndices](size_t a, size_t b) {
627- return frameIndices[a] < frameIndices[b];
633+ argsort.begin (),
634+ argsort.end (),
635+ [&frameIndicesAccessor](size_t a, size_t b) {
636+ return frameIndicesAccessor[a] < frameIndicesAccessor[b];
628637 });
629638 }
630639
@@ -633,12 +642,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
633642 const auto & streamInfo = streamInfos_[activeStreamIndex_];
634643 const auto & videoStreamOptions = streamInfo.videoStreamOptions ;
635644 FrameBatchOutput frameBatchOutput (
636- frameIndices.size (), videoStreamOptions, streamMetadata);
645+ frameIndices.numel (), videoStreamOptions, streamMetadata);
637646
638647 auto previousIndexInVideo = -1 ;
639- for (size_t f = 0 ; f < frameIndices.size (); ++f) {
648+ for (int64_t f = 0 ; f < frameIndices.numel (); ++f) {
640649 auto indexInOutput = indicesAreSorted ? f : argsort[f];
641- auto indexInVideo = frameIndices [indexInOutput];
650+ auto indexInVideo = frameIndicesAccessor [indexInOutput];
642651
643652 if ((f > 0 ) && (indexInVideo == previousIndexInVideo)) {
644653 // Avoid decoding the same frame twice
@@ -780,7 +789,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
780789 frameIndices[i] = secondsToIndexLowerBound (frameSeconds);
781790 }
782791
783- return getFramesAtIndices (frameIndices);
792+ // TODO: Support tensors natively instead of a vector to avoid a copy.
793+ return getFramesAtIndices (torch::tensor (frameIndices));
784794}
785795
786796FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange (
@@ -1202,6 +1212,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12021212 if (status == AVERROR_EOF) {
12031213 // End of file reached. We must drain the decoder
12041214 if (useCustomInterface) {
1215+ // TODONVDEC P0: Re-think this. This should be simpler.
12051216 AutoAVPacket eofAutoPacket;
12061217 ReferenceAVPacket eofPacket (eofAutoPacket);
12071218 eofPacket->data = nullptr ;
0 commit comments