diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index e0462c28..99db8988 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -23,10 +23,10 @@ void AVIOContextHolder::createAVIOContext( buffer != nullptr, "Failed to allocate buffer of size " + std::to_string(bufferSize)); - TORCH_CHECK( - (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), - "seek method must be defined, and either write or read must be defined. " - "But not both!") + // TORCH_CHECK( + // (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)), + // "seek method must be defined, and either write or read must be + // defined. " "But not both!") avioContext_.reset(avioAllocContext( buffer, bufferSize, diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 5497f89b..3870e5a1 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) py::hasattr(fileLike, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, nullptr, &seek, &fileLike_); + createAVIOContext(&read, &write, &seek, &fileLike_); } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { @@ -77,4 +77,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { return py::cast((*fileLike)->attr("seek")(offset, whence)); } +int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) { + auto fileLike = static_cast(opaque); + py::gil_scoped_acquire gil; + py::bytes bytes_obj(reinterpret_cast(buf), buf_size); + + return py::cast((*fileLike)->attr("write")(bytes_obj)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 3e80f1c6..00948515 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -24,6 +24,7 @@ class AVIOFileLikeContext : public AVIOContextHolder { private: static int read(void* opaque, uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); + static int write(void* opaque, const uint8_t* buf, int buf_size); // Note that we dynamically allocate the Python object because we need to // strictly control when its destructor is called. We must hold the GIL diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 1e0e75c3..9277ccaa 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -1,6 +1,7 @@ #include #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/AVIOContextHolder.h" #include "src/torchcodec/_core/Encoder.h" #include "torch/types.h" @@ -14,6 +15,18 @@ torch::Tensor validateWf(torch::Tensor wf) { "waveform must have float32 dtype, got ", wf.dtype()); TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim()); + + // We enforce this, but if we get user reports we should investigate whether + // that's actually needed. + int numChannels = static_cast(wf.sizes()[0]); + TORCH_CHECK( + numChannels <= AV_NUM_DATA_POINTERS, + "Trying to encode ", + numChannels, + " channels, but FFmpeg only supports ", + AV_NUM_DATA_POINTERS, + " channels per frame."); + return wf.contiguous(); } @@ -136,6 +149,31 @@ AudioEncoder::AudioEncoder( initializeEncoder(sampleRate, bitRate); } +// TODO this sucks, shouldn't need 2 separate constructors for AVIOContextHolder +AudioEncoder::AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + std::optional bitRate) + : wf_(validateWf(wf)), avioContextHolderrrr_(std::move(avioContextHolder)) { + setFFmpegLogLevel(); + AVFormatContext* avFormatContext = nullptr; + int status = avformat_alloc_output_context2( + &avFormatContext, nullptr, formatName.data(), nullptr); + + TORCH_CHECK( + avFormatContext != nullptr, + "Couldn't allocate AVFormatContext. ", + "Check the desired extension? ", + getFFMPEGErrorStringFromErrorCode(status)); + avFormatContext_.reset(avFormatContext); + + avFormatContext_->pb = avioContextHolderrrr_->getAVIOContext(); + + initializeEncoder(sampleRate, bitRate); +} + void AudioEncoder::initializeEncoder( int sampleRate, std::optional bitRate) { @@ -164,18 +202,7 @@ void AudioEncoder::initializeEncoder( // what the `.sample_fmt` defines. avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec); - int numChannels = static_cast(wf_.sizes()[0]); - TORCH_CHECK( - // TODO-ENCODING is this even true / needed? We can probably support more - // with non-planar data? - numChannels <= AV_NUM_DATA_POINTERS, - "Trying to encode ", - numChannels, - " channels, but FFmpeg only supports ", - AV_NUM_DATA_POINTERS, - " channels per frame."); - - setDefaultChannelLayout(avCodecContext_, numChannels); + setDefaultChannelLayout(avCodecContext_, static_cast(wf_.sizes()[0])); int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr); TORCH_CHECK( @@ -206,9 +233,12 @@ torch::Tensor AudioEncoder::encodeToTensor() { } void AudioEncoder::encode() { - // TODO-ENCODING: Need to check, but consecutive calls to encode() are - // probably invalid. We can address this once we (re)design the public and - // private encoding APIs. + // To be on the safe side we enforce that encode() can only be called once on + // an encoder object. Whether this is actually necessary is unknown, so this + // may be relaxed if needed. + TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice."); + encodeWasCalled_ = true; + UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame."); // Default to 256 like in torchaudio @@ -322,14 +352,17 @@ void AudioEncoder::encodeInnerLoop( ReferenceAVPacket packet(autoAVPacket); status = avcodec_receive_packet(avCodecContext_.get(), packet.get()); if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) { - // TODO-ENCODING this is from TorchAudio, probably needed, but not sure. - // if (status == AVERROR_EOF) { - // status = av_interleaved_write_frame(avFormatContext_.get(), - // nullptr); TORCH_CHECK( - // status == AVSUCCESS, - // "Failed to flush packet ", - // getFFMPEGErrorStringFromErrorCode(status)); - // } + if (status == AVERROR_EOF) { + // Flush the packets that were potentially buffered by + // av_interleaved_write_frame(). See corresponding block in + // TorchAudio: + // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21 + status = av_interleaved_write_frame(avFormatContext_.get(), nullptr); + TORCH_CHECK( + status == AVSUCCESS, + "Failed to flush packet: ", + getFFMPEGErrorStringFromErrorCode(status)); + } return; } TORCH_CHECK( diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 17f09d59..37d9c703 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -1,6 +1,7 @@ #pragma once #include #include "src/torchcodec/_core/AVIOBytesContext.h" +#include "src/torchcodec/_core/AVIOContextHolder.h" #include "src/torchcodec/_core/FFMPEGCommon.h" namespace facebook::torchcodec { @@ -28,6 +29,12 @@ class AudioEncoder { std::string_view formatName, std::unique_ptr avioContextHolder, std::optional bitRate = std::nullopt); + AudioEncoder( + const torch::Tensor wf, + int sampleRate, + std::string_view formatName, + std::unique_ptr avioContextHolder, + std::optional bitRate = std::nullopt); void encode(); torch::Tensor encodeToTensor(); @@ -49,5 +56,8 @@ class AudioEncoder { // Stores the AVIOContext for the output tensor buffer. std::unique_ptr avioContextHolder_; + std::unique_ptr avioContextHolderrrr_; // EWWWWW + + bool encodeWasCalled_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index 77fc7b85..3d340bff 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -23,6 +23,7 @@ create_from_file_like, create_from_tensor, encode_audio_to_file, + encode_audio_to_file_like, encode_audio_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 2f470617..813c53a7 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -394,8 +394,6 @@ void encode_audio_to_file( .encode(); } -// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with -// "sample_format" which we may eventually want to expose. at::Tensor encode_audio_to_tensor( const at::Tensor wf, int64_t sample_rate, diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index e9b4faec..e1205794 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -153,6 +153,17 @@ def create_from_file_like( return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode)) +def encode_audio_to_file_like( + file_like: Union[io.RawIOBase, io.BufferedReader], + wf: torch.Tensor, + sample_rate: int, + format: str, + bit_rate: Optional[int] = None, +): + assert _pybind_ops is not None + _pybind_ops.encode_audio_to_file_like(file_like, wf, sample_rate, format, bit_rate) + + # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 6f873f5a..7d1d2e76 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -10,6 +10,7 @@ #include #include "src/torchcodec/_core/AVIOFileLikeContext.h" +#include "src/torchcodec/_core/Encoder.h" #include "src/torchcodec/_core/SingleStreamDecoder.h" namespace py = pybind11; @@ -38,8 +39,26 @@ int64_t create_from_file_like( return reinterpret_cast(decoder); } +void encode_audio_to_file_like( + py::object file_like, + // const at::Tensor wf, + [[maybe_unused]] int wf, + int64_t sample_rate, + std::string_view format, + std::optional bit_rate = std::nullopt) { + auto avioContextHolder = std::make_unique(file_like); + AudioEncoder( + torch::empty({2, 1000}, torch::kFloat32), + sample_rate, // TODO need validateSampleRate + format, + std::move(avioContextHolder), + bit_rate) + .encode(); +} + PYBIND11_MODULE(decoder_core_pybind_ops, m) { m.def("create_from_file_like", &create_from_file_like); + m.def("encode_audio_to_file_like", &encode_audio_to_file_like); } } // namespace facebook::torchcodec diff --git a/test/test_ops.py b/test/test_ops.py index ddca330a..6e53d27b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path): with pytest.raises(RuntimeError, match="No such file or directory"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3" + wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3" ) with pytest.raises(RuntimeError, match="Check the desired extension"): encode_audio_to_file( - wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension" + wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension" ) with pytest.raises(RuntimeError, match="invalid sample rate=10"): @@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path): bit_rate=-1, # bad ) + with pytest.raises(RuntimeError, match="Trying to encode 10 channels"): + encode_audio_to_file( + wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter" + ) + @pytest.mark.parametrize( "encode_method", (encode_audio_to_file, encode_audio_to_tensor) )