Skip to content

Encoding: address some TODOs #667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 30 additions & 23 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) {
"waveform must have float32 dtype, got ",
wf.dtype());
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());

// We enforce this, but if we get user reports we should investigate whether
// that's actually needed.
int numChannels = static_cast<int>(wf.sizes()[0]);
TORCH_CHECK(
numChannels <= AV_NUM_DATA_POINTERS,
"Trying to encode ",
numChannels,
" channels, but FFmpeg only supports ",
AV_NUM_DATA_POINTERS,
" channels per frame.");

return wf.contiguous();
}

Expand Down Expand Up @@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder(
// what the `.sample_fmt` defines.
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);

int numChannels = static_cast<int>(wf_.sizes()[0]);
TORCH_CHECK(
// TODO-ENCODING is this even true / needed? We can probably support more
// with non-planar data?
numChannels <= AV_NUM_DATA_POINTERS,
"Trying to encode ",
numChannels,
" channels, but FFmpeg only supports ",
AV_NUM_DATA_POINTERS,
" channels per frame.");

setDefaultChannelLayout(avCodecContext_, numChannels);
setDefaultChannelLayout(avCodecContext_, static_cast<int>(wf_.sizes()[0]));

int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
TORCH_CHECK(
Expand Down Expand Up @@ -206,9 +207,12 @@ torch::Tensor AudioEncoder::encodeToTensor() {
}

void AudioEncoder::encode() {
// TODO-ENCODING: Need to check, but consecutive calls to encode() are
// probably invalid. We can address this once we (re)design the public and
// private encoding APIs.
// To be on the safe side we enforce that encode() can only be called once on
// an encoder object. Whether this is actually necessary is unknown, so this
// may be relaxed if needed.
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
encodeWasCalled_ = true;

UniqueAVFrame avFrame(av_frame_alloc());
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
// Default to 256 like in torchaudio
Expand Down Expand Up @@ -322,14 +326,17 @@ void AudioEncoder::encodeInnerLoop(
ReferenceAVPacket packet(autoAVPacket);
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
// TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
// if (status == AVERROR_EOF) {
// status = av_interleaved_write_frame(avFormatContext_.get(),
// nullptr); TORCH_CHECK(
// status == AVSUCCESS,
// "Failed to flush packet ",
// getFFMPEGErrorStringFromErrorCode(status));
// }
if (status == AVERROR_EOF) {
// Flush the packets that were potentially buffered by
// av_interleaved_write_frame(). See corresponding block in
// TorchAudio:
// https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"Failed to flush packet: ",
getFFMPEGErrorStringFromErrorCode(status));
}
return;
}
TORCH_CHECK(
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,7 @@ class AudioEncoder {

// Stores the AVIOContext for the output tensor buffer.
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;

bool encodeWasCalled_ = false;
};
} // namespace facebook::torchcodec
2 changes: 0 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,6 @@ void encode_audio_to_file(
.encode();
}

// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with
// "sample_format" which we may eventually want to expose.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torchaudio.load() uses format as well, so I think that's fine. We haven't discussed the public API yet anyway.

at::Tensor encode_audio_to_tensor(
const at::Tensor wf,
int64_t sample_rate,
Expand Down
9 changes: 7 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path):

with pytest.raises(RuntimeError, match="No such file or directory"):
encode_audio_to_file(
wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3"
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
)
with pytest.raises(RuntimeError, match="Check the desired extension"):
encode_audio_to_file(
wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
)

with pytest.raises(RuntimeError, match="invalid sample rate=10"):
Expand All @@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path):
bit_rate=-1, # bad
)

with pytest.raises(RuntimeError, match="Trying to encode 10 channels"):
encode_audio_to_file(
wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter"
)

@pytest.mark.parametrize(
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
)
Expand Down
Loading