11#include < sstream>
22
3+ #include " src/torchcodec/_core/AVIOBytesContext.h"
34#include " src/torchcodec/_core/Encoder.h"
45#include " torch/types.h"
56
67namespace facebook ::torchcodec {
78
89namespace {
910
11+ torch::Tensor validateWf (torch::Tensor wf) {
12+ TORCH_CHECK (
13+ wf.dtype () == torch::kFloat32 ,
14+ " waveform must have float32 dtype, got " ,
15+ wf.dtype ());
16+ // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
17+ // planar (fltp).
18+ TORCH_CHECK (wf.dim () == 2 , " waveform must have 2 dimensions, got " , wf.dim ());
19+ return wf;
20+ }
21+
1022void validateSampleRate (const AVCodec& avCodec, int sampleRate) {
1123 if (avCodec.supported_samplerates == nullptr ) {
1224 return ;
@@ -80,38 +92,55 @@ AudioEncoder::AudioEncoder(
8092 int sampleRate,
8193 std::string_view fileName,
8294 std::optional<int64_t > bitRate)
83- : wf_(wf) {
84- TORCH_CHECK (
85- wf_.dtype () == torch::kFloat32 ,
86- " waveform must have float32 dtype, got " ,
87- wf_.dtype ());
88- // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
89- // planar (fltp).
90- TORCH_CHECK (
91- wf_.dim () == 2 , " waveform must have 2 dimensions, got " , wf_.dim ());
92-
95+ : wf_(validateWf(wf)) {
9396 setFFmpegLogLevel ();
9497 AVFormatContext* avFormatContext = nullptr ;
95- auto status = avformat_alloc_output_context2 (
98+ int status = avformat_alloc_output_context2 (
9699 &avFormatContext, nullptr , nullptr , fileName.data ());
100+
97101 TORCH_CHECK (
98102 avFormatContext != nullptr ,
99103 " Couldn't allocate AVFormatContext. " ,
100104 " Check the desired extension? " ,
101105 getFFMPEGErrorStringFromErrorCode (status));
102106 avFormatContext_.reset (avFormatContext);
103107
104- // TODO-ENCODING: Should also support encoding into bytes (use
105- // AVIOBytesContext)
106- TORCH_CHECK (
107- !(avFormatContext->oformat ->flags & AVFMT_NOFILE),
108- " AVFMT_NOFILE is set. We only support writing to a file." );
109108 status = avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
110109 TORCH_CHECK (
111110 status >= 0 ,
112111 " avio_open failed: " ,
113112 getFFMPEGErrorStringFromErrorCode (status));
114113
114+ initializeEncoder (sampleRate, bitRate);
115+ }
116+
117+ AudioEncoder::AudioEncoder (
118+ const torch::Tensor wf,
119+ int sampleRate,
120+ std::string_view formatName,
121+ std::unique_ptr<AVIOToTensorContext> avioContextHolder,
122+ std::optional<int64_t > bitRate)
123+ : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
124+ setFFmpegLogLevel ();
125+ AVFormatContext* avFormatContext = nullptr ;
126+ int status = avformat_alloc_output_context2 (
127+ &avFormatContext, nullptr , formatName.data (), nullptr );
128+
129+ TORCH_CHECK (
130+ avFormatContext != nullptr ,
131+ " Couldn't allocate AVFormatContext. " ,
132+ " Check the desired extension? " ,
133+ getFFMPEGErrorStringFromErrorCode (status));
134+ avFormatContext_.reset (avFormatContext);
135+
136+ avFormatContext_->pb = avioContextHolder_->getAVIOContext ();
137+
138+ initializeEncoder (sampleRate, bitRate);
139+ }
140+
141+ void AudioEncoder::initializeEncoder (
142+ int sampleRate,
143+ std::optional<int64_t > bitRate) {
115144 // We use the AVFormatContext's default codec for that
116145 // specific format/container.
117146 const AVCodec* avCodec =
@@ -150,7 +179,7 @@ AudioEncoder::AudioEncoder(
150179
151180 setDefaultChannelLayout (avCodecContext_, numChannels);
152181
153- status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
182+ int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
154183 TORCH_CHECK (
155184 status == AVSUCCESS,
156185 " avcodec_open2 failed: " ,
@@ -170,7 +199,18 @@ AudioEncoder::AudioEncoder(
170199 streamIndex_ = avStream->index ;
171200}
172201
202+ torch::Tensor AudioEncoder::encodeToTensor () {
203+ TORCH_CHECK (
204+ avioContextHolder_ != nullptr ,
205+ " Cannot encode to tensor, avio context doesn't exist." );
206+ encode ();
207+ return avioContextHolder_->getOutputTensor ();
208+ }
209+
173210void AudioEncoder::encode () {
211+ // TODO-ENCODING: Need to check, but consecutive calls to encode() are
212+ // probably invalid. We can address this once we (re)design the public and
213+ // private encoding APIs.
174214 UniqueAVFrame avFrame (av_frame_alloc ());
175215 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
176216 // Default to 256 like in torchaudio
0 commit comments