Skip to content

Commit

Permalink
Add YUV420P format support to Streamer API (#2334)
Browse files Browse the repository at this point in the history
Summary:
This commit adds YUV420P format support to Streamer API.
When the native format of a video is YUV420P, the Streamer will
output Tensor of YUV color channel.

Pull Request resolved: #2334

Reviewed By: hwangjeff

Differential Revision: D35632916

Pulled By: mthrok

fbshipit-source-id: a7a0078788433060266b8bd3e7cad023f41389f5
  • Loading branch information
mthrok authored and facebook-github-bot committed Apr 14, 2022
1 parent c262758 commit 2f70e2f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 8 deletions.
51 changes: 51 additions & 0 deletions torchaudio/csrc/ffmpeg/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,55 @@ void AudioBuffer::push_frame(AVFrame* frame) {
// Modifiers - Push Video
////////////////////////////////////////////////////////////////////////////////
namespace {
torch::Tensor convert_yuv420p(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;

auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);

torch::Tensor y = torch::empty({1, height, width, 1}, options);
{
uint8_t* tgt = y.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) {
memcpy(tgt, src, width);
tgt += width;
src += linesize;
}
}
torch::Tensor u = torch::empty({1, height / 2, width / 2, 1}, options);
{
uint8_t* tgt = u.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1];
int linesize = pFrame->linesize[1];
for (int h = 0; h < height / 2; ++h) {
memcpy(tgt, src, width / 2);
tgt += width / 2;
src += linesize;
}
}
torch::Tensor v = torch::empty({1, height / 2, width / 2, 1}, options);
{
uint8_t* tgt = v.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[2];
int linesize = pFrame->linesize[2];
for (int h = 0; h < height / 2; ++h) {
memcpy(tgt, src, width / 2);
tgt += width / 2;
src += linesize;
}
}
torch::Tensor uv = torch::cat({u, v}, -1);
// Upsample width and height
uv = uv.repeat_interleave(2, -2).repeat_interleave(2, -3);
torch::Tensor t = torch::cat({y, uv}, -1);
return t.permute({0, 3, 1, 2}); // NCHW
}

torch::Tensor convert_image_tensor(AVFrame* pFrame) {
// ref:
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
Expand All @@ -189,6 +238,8 @@ torch::Tensor convert_image_tensor(AVFrame* pFrame) {
case AV_PIX_FMT_GRAY8:
channel = 1;
break;
case AV_PIX_FMT_YUV420P:
return convert_yuv420p(pFrame);
default:
throw std::runtime_error(
"Unexpected format: " + std::string(av_get_pix_fmt_name(format)));
Expand Down
31 changes: 23 additions & 8 deletions torchaudio/csrc/ffmpeg/prototype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,29 @@ std::string get_vfilter_desc(
// Check other useful formats
// https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
AVPixelFormat fmt = [&]() {
std::string val = format.value();
if (val == "RGB")
return AV_PIX_FMT_RGB24;
if (val == "BGR")
return AV_PIX_FMT_BGR24;
if (val == "GRAY")
return AV_PIX_FMT_GRAY8;
throw std::runtime_error("Unexpected format: " + val);
const std::map<const std::string, enum AVPixelFormat> valid_choices {
{"RGB", AV_PIX_FMT_RGB24},
{"BGR", AV_PIX_FMT_BGR24},
{"YUV", AV_PIX_FMT_YUV420P},
{"GRAY", AV_PIX_FMT_GRAY8},
};

const std::string val = format.value();
if (valid_choices.find(val) == valid_choices.end()) {
std::stringstream ss;
ss << "Unexpected output video format: \"" << val << "\"."
<< "Valid choices are; ";
int i = 0;
for (const auto& p : valid_choices) {
if (i == 0) {
ss << "\"" << p.first << "\"";
} else {
ss << ", \"" << p.first << "\"";
}
}
throw std::runtime_error(ss.str());
}
return valid_choices.at(val);
}();
components.emplace_back(
string_format("format=pix_fmts=%s", av_get_pix_fmt_name(fmt)));
Expand Down
1 change: 1 addition & 0 deletions torchaudio/prototype/io/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def add_basic_video_stream(
- `RGB`: 8 bits * 3 channels
- `BGR`: 8 bits * 3 channels
- `YUV`: 8 bits * 3 channels
- `GRAY`: 8 bits * 1 channels
"""
i = self.default_video_stream if stream_index is None else stream_index
Expand Down

0 comments on commit 2f70e2f

Please sign in to comment.