diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index e79a194c..2561d84e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -847,7 +847,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
 }
 
 VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
-    VideoDecoder::RawDecodedOutput& rawOutput) {
+    VideoDecoder::RawDecodedOutput& rawOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   DecodedOutput output;
   int streamIndex = rawOutput.streamIndex;
@@ -862,8 +863,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   output.durationSeconds = ptsToSeconds(
       getDuration(frame), formatContext_->streams[streamIndex]->time_base);
   if (streamInfo.options.device.type() == torch::kCPU) {
-    convertAVFrameToDecodedOutputOnCPU(rawOutput, output);
+    convertAVFrameToDecodedOutputOnCPU(
+        rawOutput, output, preAllocatedOutputTensor);
   } else if (streamInfo.options.device.type() == torch::kCUDA) {
+    // TODO: handle pre-allocated output tensor
     convertAVFrameToDecodedOutputOnCuda(
         streamInfo.options.device,
         streamInfo.options,
@@ -879,16 +882,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     VideoDecoder::RawDecodedOutput& rawOutput,
-    DecodedOutput& output) {
+    DecodedOutput& output,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      int width = streamInfo.options.width.value_or(frame->width);
-      int height = streamInfo.options.height.value_or(frame->height);
-      torch::Tensor tensor = torch::empty(
-          {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+      torch::Tensor tensor;
+      if (preAllocatedOutputTensor.has_value()) {
+        // TODO: check shape of preAllocatedOutputTensor?
+        tensor = preAllocatedOutputTensor.value();
+      } else {
+        int width = streamInfo.options.width.value_or(frame->width);
+        int height = streamInfo.options.height.value_or(frame->height);
+        tensor = torch::empty(
+            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+      }
+
       rawOutput.data = tensor.data_ptr();
       convertFrameToBufferUsingSwsScale(rawOutput);
 
@@ -981,7 +992,8 @@ void VideoDecoder::validateFrameIndex(
 
 VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
-    int64_t frameIndex) {
+    int64_t frameIndex,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFrameAtIndex");
 
@@ -990,7 +1002,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
 
   int64_t pts = stream.allFrames[frameIndex].pts;
   setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-  return getNextDecodedOutputNoDemux();
+  return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
 }
 
 VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1062,8 +1074,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
 
   BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
-    output.frames[f] = singleOut.frame;
+    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+      output.frames[f] = singleOut.frame;
+    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
@@ -1155,8 +1169,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
   int64_t numFrames = stopFrameIndex - startFrameIndex;
   BatchDecodedOutput output(numFrames, options, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i);
-    output.frames[f] = singleOut.frame;
+    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+      output.frames[f] = singleOut.frame;
+    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
@@ -1173,9 +1189,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
   return rawOutput;
 }
 
-VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
+VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   auto rawOutput = getNextRawDecodedOutputNoDemux();
-  return convertAVFrameToDecodedOutput(rawOutput);
+  return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
 }
 
 void VideoDecoder::setCursorPtsInSeconds(double seconds) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index c4da3c61..2adbfac6 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -214,7 +214,8 @@ class VideoDecoder {
   };
   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
-  DecodedOutput getNextDecodedOutputNoDemux();
+  DecodedOutput getNextDecodedOutputNoDemux(
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
   // Decodes the first frame in any added stream that is visible at a given
   // timestamp. Frames in the video have a presentation timestamp and a
   // duration. For example, if a frame has presentation timestamp of 5.0s and a
@@ -222,7 +223,10 @@ class VideoDecoder {
   // i.e. it will be returned when this function is called with seconds=5.0 or
   // seconds=5.999, etc.
   DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
-  DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
+  DecodedOutput getFrameAtIndex(
+      int streamIndex,
+      int64_t frameIndex,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
   struct BatchDecodedOutput {
     torch::Tensor frames;
     torch::Tensor ptsSeconds;
@@ -363,10 +367,13 @@ class VideoDecoder {
       int streamIndex,
       const AVFrame* frame);
   void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
-  DecodedOutput convertAVFrameToDecodedOutput(RawDecodedOutput& rawOutput);
+  DecodedOutput convertAVFrameToDecodedOutput(
+      RawDecodedOutput& rawOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
   void convertAVFrameToDecodedOutputOnCPU(
       RawDecodedOutput& rawOutput,
-      DecodedOutput& output);
+      DecodedOutput& output,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py
index 1bb28feb..18782a36 100644
--- a/test/decoders/test_video_decoder_ops.py
+++ b/test/decoders/test_video_decoder_ops.py
@@ -27,6 +27,7 @@
     get_frame_at_index,
     get_frame_at_pts,
     get_frames_at_indices,
+    get_frames_by_pts_in_range,
     get_frames_in_range,
     get_json_metadata,
     get_next_frame,
@@ -383,6 +384,47 @@ def test_color_conversion_library_with_scaling(
         swscale_frame0, _, _ = get_next_frame(swscale_decoder)
         assert_tensor_equal(filtergraph_frame0, swscale_frame0)
 
+    @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
+    @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
+    def test_color_conversion_library_with_dimension_order(
+        self, dimension_order, color_conversion_library
+    ):
+        decoder = create_from_file(str(NASA_VIDEO.path))
+        _add_video_stream(
+            decoder,
+            color_conversion_library=color_conversion_library,
+            dimension_order=dimension_order,
+        )
+        scan_all_streams_to_update_metadata(decoder)
+
+        frame0_ref = NASA_VIDEO.get_frame_data_by_index(0)
+        if dimension_order == "NHWC":
+            frame0_ref = frame0_ref.permute(1, 2, 0)
+        expected_shape = frame0_ref.shape
+
+        stream_index = 3
+        frame0, *_ = get_frame_at_index(
+            decoder, stream_index=stream_index, frame_index=0
+        )
+        assert frame0.shape == expected_shape
+        assert_tensor_equal(frame0, frame0_ref)
+
+        frame0, *_ = get_frame_at_pts(decoder, seconds=0.0)
+        assert frame0.shape == expected_shape
+        assert_tensor_equal(frame0, frame0_ref)
+
+        frames, *_ = get_frames_in_range(
+            decoder, stream_index=stream_index, start=0, stop=3
+        )
+        assert frames.shape[1:] == expected_shape
+        assert_tensor_equal(frames[0], frame0_ref)
+
+        frames, *_ = get_frames_by_pts_in_range(
+            decoder, stream_index=stream_index, start_seconds=0, stop_seconds=1
+        )
+        assert frames.shape[1:] == expected_shape
+        assert_tensor_equal(frames[0], frame0_ref)
+
     @pytest.mark.parametrize(
         "width_scaling_factor,height_scaling_factor",
         ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
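
For context, a minimal caller-side sketch of the pattern this diff enables: decoding a range of frames directly into slices of one pre-allocated batch tensor, mirroring what getFramesInRange now does internally with output.frames[f]. This is an illustration only, not part of the change. The decodeBatchPreallocated helper, the frame count, the 270x480 output size, the include path, and the namespace are assumptions; it also presumes the stream was added with the CPU swscale path and NHWC dimension order, since the filtergraph branch still returns a freshly allocated frame and the CUDA branch is marked TODO above.

// Hypothetical usage sketch (not part of this PR). Assumes the CPU swscale
// path with NHWC output; names marked "assumed" are illustrative only.
#include <torch/torch.h>

#include "VideoDecoder.h" // include path simplified for the sketch

using facebook::torchcodec::VideoDecoder; // namespace assumed

torch::Tensor decodeBatchPreallocated(VideoDecoder& decoder, int streamIndex) {
  const int64_t numFrames = 3; // assumed frame count
  const int64_t height = 270; // assumed output height
  const int64_t width = 480; // assumed output width
  // One allocation for the whole batch: uint8 NHWC, matching the tensor the
  // swscale branch of convertAVFrameToDecodedOutputOnCPU would otherwise
  // allocate per frame.
  torch::Tensor batch = torch::empty(
      {numFrames, height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8));
  for (int64_t f = 0; f < numFrames; ++f) {
    // batch[f] is a contiguous HWC view of the batch; on the swscale path the
    // pixels are written straight into it via rawOutput.data, so no per-frame
    // allocation or copy is needed. Pts/duration metadata comes back in out.
    VideoDecoder::DecodedOutput out =
        decoder.getFrameAtIndex(streamIndex, f, batch[f]);
    (void)out;
  }
  return batch;
}

The asymmetry in the batch loops above is deliberate: only when colorConversionLibrary is FILTERGRAPH does getFramesInRange copy singleOut.frame into output.frames[f]; on the swscale path the copy is skipped because the frame was already materialized in the pre-allocated slice.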