@@ -14,6 +14,18 @@ torch::Tensor validateWf(torch::Tensor wf) {
1414 " waveform must have float32 dtype, got " ,
1515 wf.dtype ());
1616 TORCH_CHECK (wf.dim () == 2 , " waveform must have 2 dimensions, got " , wf.dim ());
17+
18+ // We enforce this, but if we get user reports we should investigate whether
19+ // that's actually needed.
20+ int numChannels = static_cast <int >(wf.sizes ()[0 ]);
21+ TORCH_CHECK (
22+ numChannels <= AV_NUM_DATA_POINTERS,
23+ " Trying to encode " ,
24+ numChannels,
25+ " channels, but FFmpeg only supports " ,
26+ AV_NUM_DATA_POINTERS,
27+ " channels per frame." );
28+
1729 return wf.contiguous ();
1830}
1931
@@ -164,18 +176,7 @@ void AudioEncoder::initializeEncoder(
164176 // what the `.sample_fmt` defines.
165177 avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
166178
167- int numChannels = static_cast <int >(wf_.sizes ()[0 ]);
168- TORCH_CHECK (
169- // TODO-ENCODING is this even true / needed? We can probably support more
170- // with non-planar data?
171- numChannels <= AV_NUM_DATA_POINTERS,
172- " Trying to encode " ,
173- numChannels,
174- " channels, but FFmpeg only supports " ,
175- AV_NUM_DATA_POINTERS,
176- " channels per frame." );
177-
178- setDefaultChannelLayout (avCodecContext_, numChannels);
179+ setDefaultChannelLayout (avCodecContext_, static_cast <int >(wf_.sizes ()[0 ]));
179180
180181 int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
181182 TORCH_CHECK (
@@ -206,9 +207,12 @@ torch::Tensor AudioEncoder::encodeToTensor() {
206207}
207208
208209void AudioEncoder::encode () {
209- // TODO-ENCODING: Need to check, but consecutive calls to encode() are
210- // probably invalid. We can address this once we (re)design the public and
211- // private encoding APIs.
210+ // To be on the safe side we enforce that encode() can only be called once on
211+ // an encoder object. Whether this is actually necessary is unknown, so this
212+ // may be relaxed if needed.
213+ TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
214+ encodeWasCalled_ = true ;
215+
212216 UniqueAVFrame avFrame (av_frame_alloc ());
213217 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
214218 // Default to 256 like in torchaudio
@@ -322,14 +326,17 @@ void AudioEncoder::encodeInnerLoop(
322326 ReferenceAVPacket packet (autoAVPacket);
323327 status = avcodec_receive_packet (avCodecContext_.get (), packet.get ());
324328 if (status == AVERROR (EAGAIN) || status == AVERROR_EOF) {
325- // TODO-ENCODING this is from TorchAudio, probably needed, but not sure.
326- // if (status == AVERROR_EOF) {
327- // status = av_interleaved_write_frame(avFormatContext_.get(),
328- // nullptr); TORCH_CHECK(
329- // status == AVSUCCESS,
330- // "Failed to flush packet ",
331- // getFFMPEGErrorStringFromErrorCode(status));
332- // }
329+ if (status == AVERROR_EOF) {
330+ // Flush the packets that were potentially buffered by
331+ // av_interleaved_write_frame(). See corresponding block in
332+ // TorchAudio:
333+ // https://github.com/pytorch/audio/blob/d60ce09e2c532d5bf2e05619e700ab520543465e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L21
334+ status = av_interleaved_write_frame (avFormatContext_.get (), nullptr );
335+ TORCH_CHECK (
336+ status == AVSUCCESS,
337+ " Failed to flush packet: " ,
338+ getFFMPEGErrorStringFromErrorCode (status));
339+ }
333340 return ;
334341 }
335342 TORCH_CHECK (
0 commit comments