@@ -34,6 +34,31 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
3434 return ptsToSeconds (pts, timeBase.den );
3535}
3636
37+ // Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
38+ // The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
39+ // or 4D.
40+ // Calling permute() is guaranteed to return a view as per the docs:
41+ // https://pytorch.org/docs/stable/generated/torch.permute.html
42+ torch::Tensor MaybePermuteHWC2CHW (
43+ const VideoDecoder::VideoStreamDecoderOptions& options,
44+ torch::Tensor& hwcTensor) {
45+ if (options.dimensionOrder == " NHWC" ) {
46+ return hwcTensor;
47+ }
48+ auto numDimensions = hwcTensor.dim ();
49+ auto shape = hwcTensor.sizes ();
50+ if (numDimensions == 3 ) {
51+ TORCH_CHECK (shape[2 ] == 3 , " Not a HWC tensor: " , shape);
52+ return hwcTensor.permute ({2 , 0 , 1 });
53+ } else if (numDimensions == 4 ) {
54+ TORCH_CHECK (shape[3 ] == 3 , " Not a NHWC tensor: " , shape);
55+ return hwcTensor.permute ({0 , 3 , 1 , 2 });
56+ } else {
57+ TORCH_CHECK (
58+ false , " Expected tensor with 3 or 4 dimensions, got " , numDimensions);
59+ }
60+ }
61+
3762struct AVInput {
3863 UniqueAVFormatContext formatContext;
3964 std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -167,28 +192,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
167192 const VideoStreamDecoderOptions& options,
168193 const StreamMetadata& metadata)
169194 : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
170- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {
171- if (options.dimensionOrder == " NHWC" ) {
172- frames = torch::empty (
173- {numFrames,
174- options.height .value_or (*metadata.height ),
175- options.width .value_or (*metadata.width ),
176- 3 },
177- {torch::kUInt8 });
178- } else if (options.dimensionOrder == " NCHW" ) {
179- frames = torch::empty (
180- {numFrames,
181- 3 ,
182- options.height .value_or (*metadata.height ),
183- options.width .value_or (*metadata.width )},
184- torch::TensorOptions ()
185- .memory_format (torch::MemoryFormat::ChannelsLast)
186- .dtype ({torch::kUInt8 }));
187- } else {
188- TORCH_CHECK (
189- false , " Unsupported frame dimensionOrder =" + options.dimensionOrder )
190- }
191- }
195+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
196+ frames (torch::empty(
197+ {numFrames,
198+ options.height .value_or (*metadata.height ),
199+ options.width .value_or (*metadata.width ),
200+ 3 },
201+ {torch::kUInt8 })) {}
192202
193203VideoDecoder::VideoDecoder () {}
194204
@@ -890,22 +900,27 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
890900 if (output.streamType == AVMEDIA_TYPE_VIDEO) {
891901 if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
892902 torch::Tensor tensor;
903+ int width = streamInfo.options .width .value_or (frame->width );
904+ int height = streamInfo.options .height .value_or (frame->height );
893905 if (preAllocatedOutputTensor.has_value ()) {
894- // TODO: check shape of preAllocatedOutputTensor?
895906 tensor = preAllocatedOutputTensor.value ();
907+ auto shape = tensor.sizes ();
908+ TORCH_CHECK (
909+ (shape.size () == 3 ) && (shape[0 ] == height) &&
910+ (shape[1 ] == width) && (shape[2 ] == 3 ),
911+ " Expected tensor of shape " ,
912+ height,
913+ " x" ,
914+ width,
915+ " x3, got " ,
916+ shape);
896917 } else {
897- int width = streamInfo.options .width .value_or (frame->width );
898- int height = streamInfo.options .height .value_or (frame->height );
899918 tensor = torch::empty (
900919 {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
901920 }
902-
903921 rawOutput.data = tensor.data_ptr <uint8_t >();
904922 convertFrameToBufferUsingSwsScale (rawOutput);
905923
906- if (streamInfo.options .dimensionOrder == " NCHW" ) {
907- tensor = tensor.permute ({2 , 0 , 1 });
908- }
909924 output.frame = tensor;
910925 } else if (
911926 streamInfo.colorConversionLibrary ==
@@ -916,6 +931,14 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
916931 " Invalid color conversion library: " +
917932 std::to_string (static_cast <int >(streamInfo.colorConversionLibrary )));
918933 }
934+ if (!preAllocatedOutputTensor.has_value ()) {
935+ // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
936+ // pre-allocated tensor is passed, it's up to the caller (typically a
937+ // batch API) to do the conversion. This is more efficient as it allows
938+ // batch NHWC tensors to be permuted only once, instead of permuting HWC
939+ // tensors N times.
940+ output.frame = MaybePermuteHWC2CHW (streamInfo.options , output.frame );
941+ }
919942
920943 } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
921944 // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
@@ -1046,6 +1069,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10461069 }
10471070 i++;
10481071 }
1072+ output.frames = MaybePermuteHWC2CHW (options, output.frames );
10491073 return output;
10501074}
10511075
@@ -1081,7 +1105,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10811105 output.ptsSeconds [f] = singleOut.ptsSeconds ;
10821106 output.durationSeconds [f] = singleOut.durationSeconds ;
10831107 }
1084-
1108+ output. frames = MaybePermuteHWC2CHW (options, output. frames );
10851109 return output;
10861110}
10871111
@@ -1134,6 +1158,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11341158 // need this special case below.
11351159 if (startSeconds == stopSeconds) {
11361160 BatchDecodedOutput output (0 , options, streamMetadata);
1161+ output.frames = MaybePermuteHWC2CHW (options, output.frames );
11371162 return output;
11381163 }
11391164
@@ -1176,6 +1201,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11761201 output.ptsSeconds [f] = singleOut.ptsSeconds ;
11771202 output.durationSeconds [f] = singleOut.durationSeconds ;
11781203 }
1204+ output.frames = MaybePermuteHWC2CHW (options, output.frames );
11791205
11801206 return output;
11811207}
@@ -1302,11 +1328,6 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
13021328 torch::Tensor tensor = torch::from_blob (
13031329 filteredFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
13041330 StreamInfo& activeStream = streams_[streamIndex];
1305- if (activeStream.options .dimensionOrder == " NCHW" ) {
1306- // The docs guaranty this to return a view:
1307- // https://pytorch.org/docs/stable/generated/torch.permute.html
1308- tensor = tensor.permute ({2 , 0 , 1 });
1309- }
13101331 return tensor;
13111332}
13121333
0 commit comments