Skip to content

Commit 51d1135

Browse files
authored
Support encoding into a bytes tensor (#635)
1 parent b4619d7 commit 51d1135

File tree

14 files changed

+343
-103
lines changed

14 files changed

+343
-103
lines changed

src/torchcodec/_core/AVIOBytesContext.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize)
1313
: dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
1414
TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
1515
TORCH_CHECK(dataSize > 0, "Video data size must be positive");
16-
createAVIOContext(&read, &seek, &dataContext_);
16+
createAVIOContext(&read, nullptr, &seek, &dataContext_);
1717
}
1818

1919
// The signature of this function is defined by FFMPEG.
@@ -67,4 +67,71 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
6767
return ret;
6868
}
6969

70+
AVIOToTensorContext::AVIOToTensorContext()
71+
: dataContext_{
72+
torch::empty(
73+
{AVIOToTensorContext::INITIAL_TENSOR_SIZE},
74+
{torch::kUInt8}),
75+
0} {
76+
createAVIOContext(nullptr, &write, &seek, &dataContext_);
77+
}
78+
79+
// The signature of this function is defined by FFMPEG.
80+
int AVIOToTensorContext::write(void* opaque, const uint8_t* buf, int buf_size) {
81+
auto dataContext = static_cast<DataContext*>(opaque);
82+
83+
int64_t bufSize = static_cast<int64_t>(buf_size);
84+
if (dataContext->current + bufSize > dataContext->outputTensor.numel()) {
85+
TORCH_CHECK(
86+
dataContext->outputTensor.numel() * 2 <=
87+
AVIOToTensorContext::MAX_TENSOR_SIZE,
88+
"We tried to allocate an output encoded tensor larger than ",
89+
AVIOToTensorContext::MAX_TENSOR_SIZE,
90+
" bytes. If you think this should be supported, please report.");
91+
92+
// We double the size of the output tensor. Calling cat() may not be the
93+
// most efficient, but it's simple.
94+
dataContext->outputTensor =
95+
torch::cat({dataContext->outputTensor, dataContext->outputTensor});
96+
}
97+
98+
TORCH_CHECK(
99+
dataContext->current + bufSize <= dataContext->outputTensor.numel(),
100+
"Re-allocation of the output tensor didn't work. ",
101+
"This should not happen, please report on TorchCodec bug tracker");
102+
103+
uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
104+
std::memcpy(outputTensorData + dataContext->current, buf, bufSize);
105+
dataContext->current += bufSize;
106+
return buf_size;
107+
}
108+
109+
// The signature of this function is defined by FFMPEG.
110+
// Note: This `seek()` implementation is very similar to that of
111+
// AVIOBytesContext. We could consider merging both classes, or do some kind of
112+
// refac, but this doesn't seem worth it ATM.
113+
int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
114+
auto dataContext = static_cast<DataContext*>(opaque);
115+
int64_t ret = -1;
116+
117+
switch (whence) {
118+
case AVSEEK_SIZE:
119+
ret = dataContext->outputTensor.numel();
120+
break;
121+
case SEEK_SET:
122+
dataContext->current = offset;
123+
ret = offset;
124+
break;
125+
default:
126+
break;
127+
}
128+
129+
return ret;
130+
}
131+
132+
torch::Tensor AVIOToTensorContext::getOutputTensor() {
133+
return dataContext_.outputTensor.narrow(
134+
/*dim=*/0, /*start=*/0, /*length=*/dataContext_.current);
135+
}
136+
70137
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOBytesContext.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
#pragma once
88

9+
#include <torch/types.h>
910
#include "src/torchcodec/_core/AVIOContextHolder.h"
1011

1112
namespace facebook::torchcodec {
1213

13-
// Enables users to pass in the entire video as bytes. Our read and seek
14-
// functions then traverse the bytes in memory.
14+
// For Decoding: enables users to pass in the entire video or audio as bytes.
15+
// Our read and seek functions then traverse the bytes in memory.
1516
class AVIOBytesContext : public AVIOContextHolder {
1617
public:
1718
explicit AVIOBytesContext(const void* data, int64_t dataSize);
@@ -29,4 +30,25 @@ class AVIOBytesContext : public AVIOContextHolder {
2930
DataContext dataContext_;
3031
};
3132

33+
// For Encoding: used to encode into an output uint8 (bytes) tensor.
34+
class AVIOToTensorContext : public AVIOContextHolder {
35+
public:
36+
explicit AVIOToTensorContext();
37+
torch::Tensor getOutputTensor();
38+
39+
private:
40+
struct DataContext {
41+
torch::Tensor outputTensor;
42+
int64_t current;
43+
};
44+
45+
static constexpr int64_t INITIAL_TENSOR_SIZE = 10'000'000; // 10MB
46+
static constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
47+
static int write(void* opaque, const uint8_t* buf, int buf_size);
48+
// We need to expose seek() for some formats like mp3.
49+
static int64_t seek(void* opaque, int64_t offset, int whence);
50+
51+
DataContext dataContext_;
52+
};
53+
3254
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOContextHolder.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace facebook::torchcodec {
1111

1212
void AVIOContextHolder::createAVIOContext(
1313
AVIOReadFunction read,
14+
AVIOWriteFunction write,
1415
AVIOSeekFunction seek,
1516
void* heldData,
1617
int bufferSize) {
@@ -22,13 +23,17 @@ void AVIOContextHolder::createAVIOContext(
2223
buffer != nullptr,
2324
"Failed to allocate buffer of size " + std::to_string(bufferSize));
2425

25-
avioContext_.reset(avio_alloc_context(
26+
TORCH_CHECK(
27+
(seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
28+
"seek method must be defined, and either write or read must be defined. "
29+
"But not both!")
30+
avioContext_.reset(avioAllocContext(
2631
buffer,
2732
bufferSize,
28-
0,
33+
/*write_flag=*/write != nullptr,
2934
heldData,
3035
read,
31-
nullptr, // write function; not supported yet
36+
write,
3237
seek));
3338

3439
if (!avioContext_) {

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ namespace facebook::torchcodec {
1919
// freed.
2020
// 2. It is a base class for AVIOContext specializations. When specializing an
2121
// AVIOContext, we need to provide four things:
22-
// 1. A read callback function.
23-
// 2. A seek callback function.
24-
// 3. A write callback function. (Not supported yet; it's for encoding.)
22+
// 1. A read callback function, for decoding.
23+
// 2. A seek callback function, for decoding and encoding.
24+
// 3. A write callback function, for encoding.
2525
// 4. A pointer to some context object that has the same lifetime as the
2626
// AVIOContext itself. This context object holds the custom state that
2727
// tracks the custom behavior of reading, seeking and writing. It is
@@ -44,13 +44,10 @@ class AVIOContextHolder {
4444
// enforced by having pure virtual methods, but we don't have any.)
4545
AVIOContextHolder() = default;
4646

47-
// These signatures are defined by FFmpeg.
48-
using AVIOReadFunction = int (*)(void*, uint8_t*, int);
49-
using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
50-
5147
// Deriving classes should call this function in their constructor.
5248
void createAVIOContext(
5349
AVIOReadFunction read,
50+
AVIOWriteFunction write,
5451
AVIOSeekFunction seek,
5552
void* heldData,
5653
int bufferSize = defaultBufferSize);

src/torchcodec/_core/AVIOFileLikeContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
2323
py::hasattr(fileLike, "seek"),
2424
"File like object must implement a seek method.");
2525
}
26-
createAVIOContext(&read, &seek, &fileLike_);
26+
createAVIOContext(&read, nullptr, &seek, &fileLike_);
2727
}
2828

2929
int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {

src/torchcodec/_core/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ function(make_torchcodec_libraries
6565
set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}")
6666
set(decoder_sources
6767
AVIOContextHolder.cpp
68+
AVIOBytesContext.cpp
6869
FFMPEGCommon.cpp
69-
DeviceInterface.cpp
70+
DeviceInterface.cpp
7071
SingleStreamDecoder.cpp
7172
# TODO: lib name should probably not be "*_decoder*" now that it also
7273
# contains an encoder

src/torchcodec/_core/Encoder.cpp

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
#include <sstream>
22

3+
#include "src/torchcodec/_core/AVIOBytesContext.h"
34
#include "src/torchcodec/_core/Encoder.h"
45
#include "torch/types.h"
56

67
namespace facebook::torchcodec {
78

89
namespace {
910

11+
torch::Tensor validateWf(torch::Tensor wf) {
12+
TORCH_CHECK(
13+
wf.dtype() == torch::kFloat32,
14+
"waveform must have float32 dtype, got ",
15+
wf.dtype());
16+
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
17+
// planar (fltp).
18+
TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
19+
return wf;
20+
}
21+
1022
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
1123
if (avCodec.supported_samplerates == nullptr) {
1224
return;
@@ -80,38 +92,55 @@ AudioEncoder::AudioEncoder(
8092
int sampleRate,
8193
std::string_view fileName,
8294
std::optional<int64_t> bitRate)
83-
: wf_(wf) {
84-
TORCH_CHECK(
85-
wf_.dtype() == torch::kFloat32,
86-
"waveform must have float32 dtype, got ",
87-
wf_.dtype());
88-
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
89-
// planar (fltp).
90-
TORCH_CHECK(
91-
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
92-
95+
: wf_(validateWf(wf)) {
9396
setFFmpegLogLevel();
9497
AVFormatContext* avFormatContext = nullptr;
95-
auto status = avformat_alloc_output_context2(
98+
int status = avformat_alloc_output_context2(
9699
&avFormatContext, nullptr, nullptr, fileName.data());
100+
97101
TORCH_CHECK(
98102
avFormatContext != nullptr,
99103
"Couldn't allocate AVFormatContext. ",
100104
"Check the desired extension? ",
101105
getFFMPEGErrorStringFromErrorCode(status));
102106
avFormatContext_.reset(avFormatContext);
103107

104-
// TODO-ENCODING: Should also support encoding into bytes (use
105-
// AVIOBytesContext)
106-
TORCH_CHECK(
107-
!(avFormatContext->oformat->flags & AVFMT_NOFILE),
108-
"AVFMT_NOFILE is set. We only support writing to a file.");
109108
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
110109
TORCH_CHECK(
111110
status >= 0,
112111
"avio_open failed: ",
113112
getFFMPEGErrorStringFromErrorCode(status));
114113

114+
initializeEncoder(sampleRate, bitRate);
115+
}
116+
117+
AudioEncoder::AudioEncoder(
118+
const torch::Tensor wf,
119+
int sampleRate,
120+
std::string_view formatName,
121+
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
122+
std::optional<int64_t> bitRate)
123+
: wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
124+
setFFmpegLogLevel();
125+
AVFormatContext* avFormatContext = nullptr;
126+
int status = avformat_alloc_output_context2(
127+
&avFormatContext, nullptr, formatName.data(), nullptr);
128+
129+
TORCH_CHECK(
130+
avFormatContext != nullptr,
131+
"Couldn't allocate AVFormatContext. ",
132+
"Check the desired extension? ",
133+
getFFMPEGErrorStringFromErrorCode(status));
134+
avFormatContext_.reset(avFormatContext);
135+
136+
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
137+
138+
initializeEncoder(sampleRate, bitRate);
139+
}
140+
141+
void AudioEncoder::initializeEncoder(
142+
int sampleRate,
143+
std::optional<int64_t> bitRate) {
115144
// We use the AVFormatContext's default codec for that
116145
// specific format/container.
117146
const AVCodec* avCodec =
@@ -150,7 +179,7 @@ AudioEncoder::AudioEncoder(
150179

151180
setDefaultChannelLayout(avCodecContext_, numChannels);
152181

153-
status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
182+
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
154183
TORCH_CHECK(
155184
status == AVSUCCESS,
156185
"avcodec_open2 failed: ",
@@ -170,7 +199,18 @@ AudioEncoder::AudioEncoder(
170199
streamIndex_ = avStream->index;
171200
}
172201

202+
torch::Tensor AudioEncoder::encodeToTensor() {
203+
TORCH_CHECK(
204+
avioContextHolder_ != nullptr,
205+
"Cannot encode to tensor, avio context doesn't exist.");
206+
encode();
207+
return avioContextHolder_->getOutputTensor();
208+
}
209+
173210
void AudioEncoder::encode() {
211+
// TODO-ENCODING: Need to check, but consecutive calls to encode() are
212+
// probably invalid. We can address this once we (re)design the public and
213+
// private encoding APIs.
174214
UniqueAVFrame avFrame(av_frame_alloc());
175215
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
176216
// Default to 256 like in torchaudio

src/torchcodec/_core/Encoder.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <torch/types.h>
3+
#include "src/torchcodec/_core/AVIOBytesContext.h"
34
#include "src/torchcodec/_core/FFMPEGCommon.h"
45

56
namespace facebook::torchcodec {
@@ -21,9 +22,19 @@ class AudioEncoder {
2122
int sampleRate,
2223
std::string_view fileName,
2324
std::optional<int64_t> bitRate = std::nullopt);
25+
AudioEncoder(
26+
const torch::Tensor wf,
27+
int sampleRate,
28+
std::string_view formatName,
29+
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
30+
std::optional<int64_t> bitRate = std::nullopt);
2431
void encode();
32+
torch::Tensor encodeToTensor();
2533

2634
private:
35+
void initializeEncoder(
36+
int sampleRate,
37+
std::optional<int64_t> bitRate = std::nullopt);
2738
void encodeInnerLoop(
2839
AutoAVPacket& autoAVPacket,
2940
const UniqueAVFrame& srcAVFrame);
@@ -35,5 +46,8 @@ class AudioEncoder {
3546
UniqueSwrContext swrContext_;
3647

3748
const torch::Tensor wf_;
49+
50+
// Stores the AVIOContext for the output tensor buffer.
51+
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
3852
};
3953
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,4 +261,27 @@ void setFFmpegLogLevel() {
261261
av_log_set_level(logLevel);
262262
}
263263

264+
AVIOContext* avioAllocContext(
265+
uint8_t* buffer,
266+
int buffer_size,
267+
int write_flag,
268+
void* opaque,
269+
AVIOReadFunction read_packet,
270+
AVIOWriteFunction write_packet,
271+
AVIOSeekFunction seek) {
272+
return avio_alloc_context(
273+
buffer,
274+
buffer_size,
275+
write_flag,
276+
opaque,
277+
read_packet,
278+
// The buf parameter of the write function is not const before FFmpeg 7.
279+
#if LIBAVFILTER_VERSION_MAJOR >= 10 // FFmpeg >= 7
280+
write_packet,
281+
#else
282+
reinterpret_cast<AVIOWriteFunctionOld>(write_packet),
283+
#endif
284+
seek);
285+
}
286+
264287
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)