diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 1e0e75c3..1c876f4e 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) { "waveform must have float32 dtype, got ", wf.dtype()); TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim()); + + // We enforce this, but if we get user reports we should investigate whether + // that's actually needed. + int numChannels = static_cast(wf.sizes()[0]); + TORCH_CHECK( + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + return wf.contiguous(); } @@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder( // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); - int numChannels = static_cast(wf_.sizes()[0]); - TORCH_CHECK( - // TODO-ENCODING is this even true / needed? We can probably support more - // with non-planar data? - numChannels <= AV_NUM_DATA_POINTERS, - "Trying to encode ", - numChannels, - " channels, but FFmpeg only supports ", - AV_NUM_DATA_POINTERS, - " channels per frame."); - - setDefaultChannelLayout(avCodecContext_, numChannels); + setDefaultChannelLayout(avCodecContext_, static_cast(wf_.sizes()[0])); int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( @@ -206,9 +207,12 @@ torch::Tensor AudioEncoder::encodeToTensor() { } void AudioEncoder::encode() { - // TODO-ENCODING: Need to check, but consecutive calls to encode() are - // probably invalid. We can address this once we (re)design the public and - // private encoding APIs. + // To be on the safe side we enforce that encode() can only be called once on + // an encoder object. Whether this is actually necessary is unknown, so this + // may be relaxed if needed. + TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + encodeWasCalled_ = true; + UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); // Default to 256 like in torchaudio @@ -322,14 +326,17 @@ void AudioEncoder::encodeInnerLoop( ReferenceAVPacket packet(autoAVPacket); status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { - // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. - // if (status == AVERROR_EOF) { - // status = av_interleaved_write_frame(avFormatContext_.get(), - // nullptr); TORCH_CHECK( - // status == AVSUCCESS, - // "Failed to flush packet ", - // getFFMPEGErrorStringFromErrorCode(status)); - // } + if (status == AVERROR_EOF) { + // Flush the packets that were potentially buffered by + // av_interleaved_write_frame(). See corresponding block in + // TorchAudio: + // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21 + status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Failed to flush packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + } return; } TORCH_CHECK( diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 17f09d59..bf31c31b 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -49,5 +49,7 @@ class AudioEncoder { // Stores the AVIOContext for the output tensor buffer. std::unique_ptr avioContextHolder_; + + bool encodeWasCalled_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 2f470617..813c53a7 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -394,8 +394,6 @@ void encode_audio_to_file( .encode(); } -// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with -// "sample_format" which we may eventually want to expose. at::Tensor encode_audio_to_tensor( const at::Tensor wf, int64_t sample_rate, diff --git a/test/test_ops.py b/test/test_ops.py index ddca330a..6e53d27b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path): with pytest.raises(RuntimeError, match="No such file or directory"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" + wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3" ) with pytest.raises(RuntimeError, match="Check the desired extension"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" + wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" ) with pytest.raises(RuntimeError, match="invalid sample rate=10"): @@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path): bit_rate=-1, # bad ) + with pytest.raises(RuntimeError, match="Trying to encode 10 channels"): + encode_audio_to_file( + wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter" + ) + @pytest.mark.parametrize( "encode_method", (encode_audio_to_file, encode_audio_to_tensor) )