Skip to content

Encoding: support file-like objects #668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/torchcodec/_core/AVIOContextHolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ void AVIOContextHolder::createAVIOContext(
buffer != nullptr,
"Failed to allocate buffer of size " + std::to_string(bufferSize));

TORCH_CHECK(
(seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
"seek method must be defined, and either write or read must be defined. "
"But not both!")
// TORCH_CHECK(
// (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
// "seek method must be defined, and either write or read must be
// defined. " "But not both!")
avioContext_.reset(avioAllocContext(
buffer,
bufferSize,
Expand Down
10 changes: 9 additions & 1 deletion src/torchcodec/_core/AVIOFileLikeContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
py::hasattr(fileLike, "seek"),
"File like object must implement a seek method.");
}
createAVIOContext(&read, nullptr, &seek, &fileLike_);
createAVIOContext(&read, &write, &seek, &fileLike_);
}

int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
Expand Down Expand Up @@ -77,4 +77,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) {
return py::cast<int64_t>((*fileLike)->attr("seek")(offset, whence));
}

int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) {
auto fileLike = static_cast<UniquePyObject*>(opaque);
py::gil_scoped_acquire gil;
py::bytes bytes_obj(reinterpret_cast<const char*>(buf), buf_size);

return py::cast<int64_t>((*fileLike)->attr("write")(bytes_obj));
}

} // namespace facebook::torchcodec
1 change: 1 addition & 0 deletions src/torchcodec/_core/AVIOFileLikeContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class AVIOFileLikeContext : public AVIOContextHolder {
private:
static int read(void* opaque, uint8_t* buf, int buf_size);
static int64_t seek(void* opaque, int64_t offset, int whence);
static int write(void* opaque, const uint8_t* buf, int buf_size);

// Note that we dynamically allocate the Python object because we need to
// strictly control when its destructor is called. We must hold the GIL
Expand Down
79 changes: 56 additions & 23 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <sstream>

#include "src/torchcodec/_core/AVIOBytesContext.h"
#include "src/torchcodec/_core/AVIOContextHolder.h"
#include "src/torchcodec/_core/Encoder.h"
#include "torch/types.h"

Expand All @@ -14,6 +15,18 @@ torch::Tensor validateWf(torch::Tensor wf) {
"waveform must have float32 dtype, got ",
wf.dtype());
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());

// We enforce this, but if we get user reports we should investigate whether
// that's actually needed.
int numChannels = static_cast<int>(wf.sizes()[0]);
TORCH_CHECK(
numChannels <= AV_NUM_DATA_POINTERS,
"Trying to encode ",
numChannels,
" channels, but FFmpeg only supports ",
AV_NUM_DATA_POINTERS,
" channels per frame.");

return wf.contiguous();
}

Expand Down Expand Up @@ -136,6 +149,31 @@ AudioEncoder::AudioEncoder(
initializeEncoder(sampleRate, bitRate);
}

// TODO this sucks, shouldn't need 2 separate constructors for AVIOContextHolder
AudioEncoder::AudioEncoder(
const torch::Tensor wf,
int sampleRate,
std::string_view formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
std::optional<int64_t> bitRate)
: wf_(validateWf(wf)), avioContextHolderrrr_(std::move(avioContextHolder)) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
&avFormatContext, nullptr, formatName.data(), nullptr);

TORCH_CHECK(
avFormatContext != nullptr,
"Couldn't allocate AVFormatContext. ",
"Check the desired extension? ",
getFFMPEGErrorStringFromErrorCode(status));
avFormatContext_.reset(avFormatContext);

avFormatContext_->pb = avioContextHolderrrr_->getAVIOContext();

initializeEncoder(sampleRate, bitRate);
}

void AudioEncoder::initializeEncoder(
int sampleRate,
std::optional<int64_t> bitRate) {
Expand Down Expand Up @@ -164,18 +202,7 @@ void AudioEncoder::initializeEncoder(
// what the `.sample_fmt` defines.
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);

int numChannels = static_cast<int>(wf_.sizes()[0]);
TORCH_CHECK(
// TODO-ENCODING is this even true / needed? We can probably support more
// with non-planar data?
numChannels <= AV_NUM_DATA_POINTERS,
"Trying to encode ",
numChannels,
" channels, but FFmpeg only supports ",
AV_NUM_DATA_POINTERS,
" channels per frame.");

setDefaultChannelLayout(avCodecContext_, numChannels);
setDefaultChannelLayout(avCodecContext_, static_cast<int>(wf_.sizes()[0]));

int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
TORCH_CHECK(
Expand Down Expand Up @@ -206,9 +233,12 @@ torch::Tensor AudioEncoder::encodeToTensor() {
}

void AudioEncoder::encode() {
// TODO-ENCODING: Need to check, but consecutive calls to encode() are
// probably invalid. We can address this once we (re)design the public and
// private encoding APIs.
// To be on the safe side we enforce that encode() can only be called once on
// an encoder object. Whether this is actually necessary is unknown, so this
// may be relaxed if needed.
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
encodeWasCalled_ = true;

UniqueAVFrame avFrame(av_frame_alloc());
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
// Default to 256 like in torchaudio
Expand Down Expand Up @@ -322,14 +352,17 @@ void AudioEncoder::encodeInnerLoop(
ReferenceAVPacket packet(autoAVPacket);
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
// TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
// if (status == AVERROR_EOF) {
// status = av_interleaved_write_frame(avFormatContext_.get(),
// nullptr); TORCH_CHECK(
// status == AVSUCCESS,
// "Failed to flush packet ",
// getFFMPEGErrorStringFromErrorCode(status));
// }
if (status == AVERROR_EOF) {
// Flush the packets that were potentially buffered by
// av_interleaved_write_frame(). See corresponding block in
// TorchAudio:
// https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
TORCH_CHECK(
status == AVSUCCESS,
"Failed to flush packet: ",
getFFMPEGErrorStringFromErrorCode(status));
}
return;
}
TORCH_CHECK(
Expand Down
10 changes: 10 additions & 0 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include <torch/types.h>
#include "src/torchcodec/_core/AVIOBytesContext.h"
#include "src/torchcodec/_core/AVIOContextHolder.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"

namespace facebook::torchcodec {
Expand Down Expand Up @@ -28,6 +29,12 @@ class AudioEncoder {
std::string_view formatName,
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
std::optional<int64_t> bitRate = std::nullopt);
AudioEncoder(
const torch::Tensor wf,
int sampleRate,
std::string_view formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
std::optional<int64_t> bitRate = std::nullopt);
void encode();
torch::Tensor encodeToTensor();

Expand All @@ -49,5 +56,8 @@ class AudioEncoder {

// Stores the AVIOContext for the output tensor buffer.
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
std::unique_ptr<AVIOContextHolder> avioContextHolderrrr_; // EWWWWW

bool encodeWasCalled_ = false;
};
} // namespace facebook::torchcodec
1 change: 1 addition & 0 deletions src/torchcodec/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
create_from_file_like,
create_from_tensor,
encode_audio_to_file,
encode_audio_to_file_like,
encode_audio_to_tensor,
get_ffmpeg_library_versions,
get_frame_at_index,
Expand Down
2 changes: 0 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,6 @@ void encode_audio_to_file(
.encode();
}

// TODO-ENCODING is "format" a good parameter name?? It kinda conflicts with
// "sample_format" which we may eventually want to expose.
at::Tensor encode_audio_to_tensor(
const at::Tensor wf,
int64_t sample_rate,
Expand Down
11 changes: 11 additions & 0 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ def create_from_file_like(
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))


def encode_audio_to_file_like(
file_like: Union[io.RawIOBase, io.BufferedReader],
wf: torch.Tensor,
sample_rate: int,
format: str,
bit_rate: Optional[int] = None,
):
assert _pybind_ops is not None
_pybind_ops.encode_audio_to_file_like(file_like, wf, sample_rate, format, bit_rate)


# ==============================
# Abstract impl for the operators. Needed by torch.compile.
# ==============================
Expand Down
19 changes: 19 additions & 0 deletions src/torchcodec/_core/pybind_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <string>

#include "src/torchcodec/_core/AVIOFileLikeContext.h"
#include "src/torchcodec/_core/Encoder.h"
#include "src/torchcodec/_core/SingleStreamDecoder.h"

namespace py = pybind11;
Expand Down Expand Up @@ -38,8 +39,26 @@ int64_t create_from_file_like(
return reinterpret_cast<int64_t>(decoder);
}

void encode_audio_to_file_like(
py::object file_like,
// const at::Tensor wf,
[[maybe_unused]] int wf,
int64_t sample_rate,
std::string_view format,
std::optional<int64_t> bit_rate = std::nullopt) {
auto avioContextHolder = std::make_unique<AVIOFileLikeContext>(file_like);
AudioEncoder(
torch::empty({2, 1000}, torch::kFloat32),
sample_rate, // TODO need validateSampleRate
format,
std::move(avioContextHolder),
bit_rate)
.encode();
}

PYBIND11_MODULE(decoder_core_pybind_ops, m) {
m.def("create_from_file_like", &create_from_file_like);
m.def("encode_audio_to_file_like", &encode_audio_to_file_like);
}

} // namespace facebook::torchcodec
9 changes: 7 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,11 +1132,11 @@ def test_bad_input(self, tmp_path):

with pytest.raises(RuntimeError, match="No such file or directory"):
encode_audio_to_file(
wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3"
wf=torch.rand(2, 10), sample_rate=10, filename="./bad/path.mp3"
)
with pytest.raises(RuntimeError, match="Check the desired extension"):
encode_audio_to_file(
wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
wf=torch.rand(2, 10), sample_rate=10, filename="./file.bad_extension"
)

with pytest.raises(RuntimeError, match="invalid sample rate=10"):
Expand All @@ -1153,6 +1153,11 @@ def test_bad_input(self, tmp_path):
bit_rate=-1, # bad
)

with pytest.raises(RuntimeError, match="Trying to encode 10 channels"):
encode_audio_to_file(
wf=torch.rand(10, 20), sample_rate=10, filename="doesnt_matter"
)

@pytest.mark.parametrize(
"encode_method", (encode_audio_to_file, encode_audio_to_tensor)
)
Expand Down
Loading