From 2f70e2f90708800d28128671d4e04744c87ab97e Mon Sep 17 00:00:00 2001
From: moto <855818+mthrok@users.noreply.github.com>
Date: Wed, 13 Apr 2022 21:12:12 -0700
Subject: [PATCH] Add YUV420P format support to Streamer API (#2334)

Summary:
This commit adds YUV420P format support to Streamer API.
When the native format of a video is YUV420P, the Streamer will output Tensor of YUV color channel.

Pull Request resolved: https://github.com/pytorch/audio/pull/2334

Reviewed By: hwangjeff

Differential Revision: D35632916

Pulled By: mthrok

fbshipit-source-id: a7a0078788433060266b8bd3e7cad023f41389f5
---
 torchaudio/csrc/ffmpeg/buffer.cpp    | 52 ++++++++++++++++++++++++++++
 torchaudio/csrc/ffmpeg/prototype.cpp | 31 ++++++++++++-----
 torchaudio/prototype/io/streamer.py  |  1 +
 3 files changed, 76 insertions(+), 8 deletions(-)

diff --git a/torchaudio/csrc/ffmpeg/buffer.cpp b/torchaudio/csrc/ffmpeg/buffer.cpp
index 0062d1f746..d7fbd6acdd 100644
--- a/torchaudio/csrc/ffmpeg/buffer.cpp
+++ b/torchaudio/csrc/ffmpeg/buffer.cpp
@@ -164,6 +164,56 @@ void AudioBuffer::push_frame(AVFrame* frame) {
 // Modifiers - Push Video
 ////////////////////////////////////////////////////////////////////////////////
 namespace {
+// Copy one 8-bit image plane into a newly allocated contiguous Tensor.
+// `linesize` is the stride of the source buffer in bytes. It can be larger
+// than `width` due to padding, so rows are copied one at a time.
+// The resulting Tensor has the shape {1, height, width, 1} (NHWC).
+torch::Tensor copy_plane(
+    const uint8_t* src,
+    int linesize,
+    int height,
+    int width) {
+  auto options = torch::TensorOptions()
+                     .dtype(torch::kUInt8)
+                     .layout(torch::kStrided)
+                     .device(torch::kCPU);
+  torch::Tensor plane = torch::empty({1, height, width, 1}, options);
+  uint8_t* tgt = plane.data_ptr<uint8_t>();
+  for (int h = 0; h < height; ++h) {
+    memcpy(tgt, src, width);
+    tgt += width;
+    src += linesize;
+  }
+  return plane;
+}
+
+// Convert a YUV420P frame into a uint8 Tensor of shape {1, 3, height, width}
+// (NCHW), of which channels are Y, U and V.
+// The U and V planes are subsampled by 2 in both directions, so they are
+// upsampled by repeating each value so that they match the Y plane.
+torch::Tensor convert_yuv420p(AVFrame* pFrame) {
+  int width = pFrame->width;
+  int height = pFrame->height;
+  // Chroma planes are rounded up when the frame size is odd.
+  int c_width = (width + 1) / 2;
+  int c_height = (height + 1) / 2;
+
+  torch::Tensor y =
+      copy_plane(pFrame->data[0], pFrame->linesize[0], height, width);
+  torch::Tensor u =
+      copy_plane(pFrame->data[1], pFrame->linesize[1], c_height, c_width);
+  torch::Tensor v =
+      copy_plane(pFrame->data[2], pFrame->linesize[2], c_height, c_width);
+
+  torch::Tensor uv = torch::cat({u, v}, -1);
+  // Upsample width and height
+  uv = uv.repeat_interleave(2, -2).repeat_interleave(2, -3);
+  // Drop the extra row/column produced when the frame size is odd.
+  uv = uv.narrow(1, 0, height).narrow(2, 0, width);
+  torch::Tensor t = torch::cat({y, uv}, -1);
+  return t.permute({0, 3, 1, 2}); // NCHW
+}
+
 torch::Tensor convert_image_tensor(AVFrame* pFrame) {
   // ref:
   // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
@@ -189,6 +239,8 @@ torch::Tensor convert_image_tensor(AVFrame* pFrame) {
     case AV_PIX_FMT_GRAY8:
       channel = 1;
       break;
+    case AV_PIX_FMT_YUV420P:
+      return convert_yuv420p(pFrame);
     default:
       throw std::runtime_error(
           "Unexpected format: " + std::string(av_get_pix_fmt_name(format)));
diff --git a/torchaudio/csrc/ffmpeg/prototype.cpp b/torchaudio/csrc/ffmpeg/prototype.cpp
index 69a0d26991..f18d65276d 100644
--- a/torchaudio/csrc/ffmpeg/prototype.cpp
+++ b/torchaudio/csrc/ffmpeg/prototype.cpp
@@ -193,14 +193,29 @@ std::string get_vfilter_desc(
     // Check other useful formats
     // https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
     AVPixelFormat fmt = [&]() {
-      std::string val = format.value();
-      if (val == "RGB")
-        return AV_PIX_FMT_RGB24;
-      if (val == "BGR")
-        return AV_PIX_FMT_BGR24;
-      if (val == "GRAY")
-        return AV_PIX_FMT_GRAY8;
-      throw std::runtime_error("Unexpected format: " + val);
+      // Mapping from the format name exposed to users to pixel format
+      const std::map<std::string, AVPixelFormat> valid_choices{
+          {"RGB", AV_PIX_FMT_RGB24},
+          {"BGR", AV_PIX_FMT_BGR24},
+          {"YUV", AV_PIX_FMT_YUV420P},
+          {"GRAY", AV_PIX_FMT_GRAY8},
+      };
+
+      const std::string val = format.value();
+      const auto it = valid_choices.find(val);
+      if (it == valid_choices.end()) {
+        // List all the valid names so that users can fix the argument.
+        std::stringstream ss;
+        ss << "Unexpected output video format: \"" << val << "\". "
+           << "Valid choices are: ";
+        const char* sep = "";
+        for (const auto& p : valid_choices) {
+          ss << sep << "\"" << p.first << "\"";
+          sep = ", ";
+        }
+        throw std::runtime_error(ss.str());
+      }
+      return it->second;
     }();
     components.emplace_back(
         string_format("format=pix_fmts=%s", av_get_pix_fmt_name(fmt)));
diff --git a/torchaudio/prototype/io/streamer.py b/torchaudio/prototype/io/streamer.py
index 34e863a667..e94433d1ba 100644
--- a/torchaudio/prototype/io/streamer.py
+++ b/torchaudio/prototype/io/streamer.py
@@ -334,6 +334,7 @@ def add_basic_video_stream(
 
                 - `RGB`: 8 bits * 3 channels
                 - `BGR`: 8 bits * 3 channels
+                - `YUV`: 8 bits * 3 channels
                 - `GRAY`: 8 bits * 1 channels
         """
         i = self.default_video_stream if stream_index is None else stream_index