
Commit b82ea81

Rework HWC / CHW dimension order conversions (#277)
1 parent d6cbee5 commit b82ea81

2 files changed, +62 -35 lines

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 56 additions & 35 deletions

@@ -34,6 +34,31 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
   return ptsToSeconds(pts, timeBase.den);
 }
 
+// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
+// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
+// or 4D.
+// Calling permute() is guaranteed to return a view as per the docs:
+// https://pytorch.org/docs/stable/generated/torch.permute.html
+torch::Tensor MaybePermuteHWC2CHW(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    torch::Tensor& hwcTensor) {
+  if (options.dimensionOrder == "NHWC") {
+    return hwcTensor;
+  }
+  auto numDimensions = hwcTensor.dim();
+  auto shape = hwcTensor.sizes();
+  if (numDimensions == 3) {
+    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
+    return hwcTensor.permute({2, 0, 1});
+  } else if (numDimensions == 4) {
+    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
+    return hwcTensor.permute({0, 3, 1, 2});
+  } else {
+    TORCH_CHECK(
+        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
+  }
+}
+
 struct AVInput {
   UniqueAVFormatContext formatContext;
   std::unique_ptr<AVIOBytesContext> ioBytesContext;
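
Note: the new MaybePermuteHWC2CHW helper above is now the single place where dimension-order conversion happens, and its key property is that permute() returns a zero-copy view. A minimal Python sketch of the same logic, using only torch; the function name and sizes below are illustrative, not part of the torchcodec API:

    import torch

    def maybe_permute_hwc2chw(dimension_order: str, hwc: torch.Tensor) -> torch.Tensor:
        # Mirrors the C++ helper: no-op for "NHWC", otherwise return a [N]CHW view.
        if dimension_order == "NHWC":
            return hwc
        if hwc.dim() == 3:
            assert hwc.shape[2] == 3, f"Not a HWC tensor: {tuple(hwc.shape)}"
            return hwc.permute(2, 0, 1)
        if hwc.dim() == 4:
            assert hwc.shape[3] == 3, f"Not a NHWC tensor: {tuple(hwc.shape)}"
            return hwc.permute(0, 3, 1, 2)
        raise ValueError(f"Expected 3 or 4 dimensions, got {hwc.dim()}")

    frames = torch.zeros(5, 270, 480, 3, dtype=torch.uint8)  # NHWC batch
    chw = maybe_permute_hwc2chw("NCHW", frames)
    assert chw.shape == (5, 3, 270, 480)
    assert chw.data_ptr() == frames.data_ptr()  # same storage: a view, not a copy
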
@@ -167,28 +192,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
     const VideoStreamDecoderOptions& options,
     const StreamMetadata& metadata)
     : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
-  if (options.dimensionOrder == "NHWC") {
-    frames = torch::empty(
-        {numFrames,
-         options.height.value_or(*metadata.height),
-         options.width.value_or(*metadata.width),
-         3},
-        {torch::kUInt8});
-  } else if (options.dimensionOrder == "NCHW") {
-    frames = torch::empty(
-        {numFrames,
-         3,
-         options.height.value_or(*metadata.height),
-         options.width.value_or(*metadata.width)},
-        torch::TensorOptions()
-            .memory_format(torch::MemoryFormat::ChannelsLast)
-            .dtype({torch::kUInt8}));
-  } else {
-    TORCH_CHECK(
-        false, "Unsupported frame dimensionOrder =" + options.dimensionOrder)
-  }
-}
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      frames(torch::empty(
+          {numFrames,
+           options.height.value_or(*metadata.height),
+           options.width.value_or(*metadata.width),
+           3},
+          {torch::kUInt8})) {}
 
 VideoDecoder::VideoDecoder() {}
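
Note: with the helper doing the conversion lazily, BatchDecodedOutput can always allocate its frames buffer in NHWC. The removed ChannelsLast branch is redundant because a contiguous NHWC allocation permuted to NCHW has exactly the strides that branch produced. A small sketch, assuming torch is available; the sizes are illustrative:

    import torch

    N, H, W = 4, 270, 480
    # New constructor behaviour: always allocate NHWC, permute to NCHW later.
    nhwc = torch.empty(N, H, W, 3, dtype=torch.uint8)
    nchw_view = nhwc.permute(0, 3, 1, 2)

    # Removed "NCHW" branch: allocate NCHW with channels-last memory format.
    nchw_cl = torch.empty(N, 3, H, W, dtype=torch.uint8, memory_format=torch.channels_last)

    # Same shape and same physical layout, so dropping the branch loses nothing.
    assert nchw_view.shape == nchw_cl.shape
    assert nchw_view.stride() == nchw_cl.stride()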

@@ -890,22 +900,27 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
       torch::Tensor tensor;
+      int width = streamInfo.options.width.value_or(frame->width);
+      int height = streamInfo.options.height.value_or(frame->height);
       if (preAllocatedOutputTensor.has_value()) {
-        // TODO: check shape of preAllocatedOutputTensor?
         tensor = preAllocatedOutputTensor.value();
+        auto shape = tensor.sizes();
+        TORCH_CHECK(
+            (shape.size() == 3) && (shape[0] == height) &&
+                (shape[1] == width) && (shape[2] == 3),
+            "Expected tensor of shape ",
+            height,
+            "x",
+            width,
+            "x3, got ",
+            shape);
       } else {
-        int width = streamInfo.options.width.value_or(frame->width);
-        int height = streamInfo.options.height.value_or(frame->height);
         tensor = torch::empty(
             {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
       }
-
       rawOutput.data = tensor.data_ptr<uint8_t>();
       convertFrameToBufferUsingSwsScale(rawOutput);
 
-      if (streamInfo.options.dimensionOrder == "NCHW") {
-        tensor = tensor.permute({2, 0, 1});
-      }
       output.frame = tensor;
     } else if (
         streamInfo.colorConversionLibrary ==
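
Note: the TORCH_CHECK added above resolves the old TODO: a pre-allocated output tensor (typically one HWC slice of a batch buffer, per the comment in the next hunk) must already have the decoded frame's height x width x 3 shape. A rough Python rendering of the same condition, with made-up sizes:

    import torch

    N, H, W = 4, 270, 480
    batch = torch.empty(N, H, W, 3, dtype=torch.uint8)  # NHWC batch buffer
    pre_allocated = batch[0]  # one HWC slice handed to the per-frame decode

    # Same condition the TORCH_CHECK enforces: a 3-D tensor of shape H x W x 3.
    shape = tuple(pre_allocated.shape)
    assert len(shape) == 3 and shape == (H, W, 3), f"Expected {H}x{W}x3, got {shape}"
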
@@ -916,6 +931,14 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
           "Invalid color conversion library: " +
               std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
     }
+    if (!preAllocatedOutputTensor.has_value()) {
+      // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
+      // pre-allocated tensor is passed, it's up to the caller (typically a
+      // batch API) to do the conversion. This is more efficient as it allows
+      // batch NHWC tensors to be permuted only once, instead of permuting HWC
+      // tensors N times.
+      output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
+    }
 
   } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
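
Note: the comment above carries the main idea of the rework: when a pre-allocated tensor is used, the caller owns a whole NHWC batch and can permute it once instead of permuting each HWC frame. A short torch sketch of why the two are equivalent and why the batched form is cheaper; tensor names are illustrative:

    import torch

    batch = torch.randint(0, 256, (8, 270, 480, 3), dtype=torch.uint8)  # NHWC

    # Per-frame conversion: N permutes, and stack() materialises a copy.
    per_frame = torch.stack([frame.permute(2, 0, 1) for frame in batch])

    # Batch conversion: one permute, still a zero-copy view of the same buffer.
    once = batch.permute(0, 3, 1, 2)

    assert torch.equal(per_frame, once)
    assert once.data_ptr() == batch.data_ptr()
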
@@ -1046,6 +1069,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     }
     i++;
   }
+  output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
 }
 
@@ -1081,7 +1105,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-
+  output.frames = MaybePermuteHWC2CHW(options, output.frames);
   return output;
 }
 
@@ -1134,6 +1158,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
+    output.frames = MaybePermuteHWC2CHW(options, output.frames);
     return output;
   }
 
@@ -1176,6 +1201,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
+  output.frames = MaybePermuteHWC2CHW(options, output.frames);
 
   return output;
 }
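
Note: each of the batch entry points above (getFramesAtIndices, getFramesInRange, and both getFramesDisplayedByTimestampInRange returns) now ends with a single MaybePermuteHWC2CHW call: frames are decoded straight into the pre-allocated NHWC buffer and the permute happens exactly once on the way out. A schematic Python version of that control flow; the function and argument names are made up for illustration:

    import torch

    def get_frames_at_indices_sketch(decode_frame_into, indices, H, W, dimension_order):
        # Pre-allocate the batch output in NHWC, as BatchDecodedOutput now always does.
        frames = torch.empty(len(indices), H, W, 3, dtype=torch.uint8)
        for i, idx in enumerate(indices):
            # Each single-frame decode writes into its HWC slice; no per-frame
            # permute happens inside the loop anymore.
            decode_frame_into(idx, frames[i])
        # Convert once at the end, only if the caller asked for NCHW.
        if dimension_order == "NCHW":
            frames = frames.permute(0, 3, 1, 2)
        return frames

    out = get_frames_at_indices_sketch(
        lambda idx, dst: dst.fill_(idx), [0, 1, 3, 4], 270, 480, "NCHW"
    )
    assert out.shape == (4, 3, 270, 480)
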
@@ -1302,11 +1328,6 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   torch::Tensor tensor = torch::from_blob(
       filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
   StreamInfo& activeStream = streams_[streamIndex];
-  if (activeStream.options.dimensionOrder == "NCHW") {
-    // The docs guaranty this to return a view:
-    // https://pytorch.org/docs/stable/generated/torch.permute.html
-    tensor = tensor.permute({2, 0, 1});
-  }
   return tensor;
 }
 
test/decoders/test_video_decoder_ops.py

Lines changed: 6 additions & 0 deletions

@@ -425,6 +425,12 @@ def test_color_conversion_library_with_dimension_order(
         assert frames.shape[1:] == expected_shape
         assert_tensor_equal(frames[0], frame0_ref)
 
+        frames = get_frames_at_indices(
+            decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
+        )
+        assert frames.shape[1:] == expected_shape
+        assert_tensor_equal(frames[0], frame0_ref)
+
     @pytest.mark.parametrize(
         "width_scaling_factor,height_scaling_factor",
         ((1.31, 1.5), (0.71, 0.5), (1.31, 0.7), (0.71, 1.5), (1.0, 1.0)),
