Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flush and reset internal state after seek #2264

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions test/torchaudio_unittest/prototype/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,22 @@ def test_stream_smoke_test(self):
if i >= 40:
break

def test_seek(self):
"""Calling `seek` multiple times should not segfault"""
s = Streamer(get_video_asset())
for i in range(10):
s.seek(i)
for _ in range(0):
s.seek(0)
for i in range(10, 0, -1):
s.seek(i)

def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception"""
s = Streamer(get_video_asset())
with self.assertRaises(ValueError):
s.seek(-1.0)


@skipIfNoFFmpeg
class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
Expand Down Expand Up @@ -345,6 +361,21 @@ def test_audio_seek(self, dtype, num_channels):
(output,) = s.pop_chunks()
self.assertEqual(expected, output)

def test_audio_seek_multiple(self):
"""Calling `seek` after streaming is started should change the position properly"""
path, original = self._get_reference_wav(1, dtype="int16", num_channels=2, num_frames=30)

s = Streamer(path)
s.add_audio_stream(frames_per_chunk=-1)

ts = list(range(20)) + list(range(20, 0, -1)) + list(range(20))
for t in ts:
s.seek(float(t))
s.process_all_packets()
(output,) = s.pop_chunks()
expected = original[t:, :]
self.assertEqual(expected, output)

@nested_params(
[
(18, 6, 3), # num_frames is divisible by frames_per_chunk
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,9 @@ torch::Tensor Buffer::pop_all() {
return torch::cat(ret, 0);
}

void Buffer::flush() {
chunks.clear();
}

} // namespace ffmpeg
} // namespace torchaudio
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Buffer {

c10::optional<torch::Tensor> pop_chunk();

void flush();

private:
virtual torch::Tensor pop_one_chunk() = 0;
torch::Tensor pop_all();
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ int Decoder::get_frame(AVFrame* pFrame) {
return avcodec_receive_frame(pCodecContext, pFrame);
}

void Decoder::flush_buffer() {
avcodec_flush_buffers(pCodecContext);
}

} // namespace ffmpeg
} // namespace torchaudio
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Decoder {
int process_packet(AVPacket* pPacket);
// Fetch a decoded frame
int get_frame(AVFrame* pFrame);
// Flush buffer (for seek)
void flush_buffer();
};

} // namespace ffmpeg
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/ffmpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,9 @@ AVFilterGraph* get_filter_graph() {
} // namespace
AVFilterGraphPtr::AVFilterGraphPtr()
: Wrapper<AVFilterGraph, AVFilterGraphDeleter>(get_filter_graph()) {}

void AVFilterGraphPtr::reset() {
ptr.reset(get_filter_graph());
}
} // namespace ffmpeg
} // namespace torchaudio
1 change: 1 addition & 0 deletions torchaudio/csrc/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ struct AVFilterGraphDeleter {
};
struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
AVFilterGraphPtr();
void reset();
};
} // namespace ffmpeg
} // namespace torchaudio
35 changes: 23 additions & 12 deletions torchaudio/csrc/ffmpeg/filter_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ FilterGraph::FilterGraph(
AVRational time_base,
AVCodecParameters* codecpar,
std::string filter_description)
: filter_description(filter_description) {
add_src(time_base, codecpar);
add_sink();
add_process();
create_filter();
: input_time_base(time_base),
codecpar(codecpar),
filter_description(std::move(filter_description)),
media_type(codecpar->codec_type) {
init();
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -62,18 +62,29 @@ std::string get_video_src_args(

} // namespace

void FilterGraph::add_src(AVRational time_base, AVCodecParameters* codecpar) {
if (media_type != AVMEDIA_TYPE_UNKNOWN) {
throw std::runtime_error("Source buffer is already allocated.");
}
media_type = codecpar->codec_type;
void FilterGraph::init() {
add_src();
add_sink();
add_process();
create_filter();
}

void FilterGraph::reset() {
pFilterGraph.reset();
buffersrc_ctx = nullptr;
buffersink_ctx = nullptr;

init();
}

void FilterGraph::add_src() {
std::string args;
switch (media_type) {
case AVMEDIA_TYPE_AUDIO:
args = get_audio_src_args(time_base, codecpar);
args = get_audio_src_args(input_time_base, codecpar);
break;
case AVMEDIA_TYPE_VIDEO:
args = get_video_src_args(time_base, codecpar);
args = get_video_src_args(input_time_base, codecpar);
break;
default:
throw std::runtime_error("Only audio/video are supported.");
Expand Down
17 changes: 14 additions & 3 deletions torchaudio/csrc/ffmpeg/filter_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@ namespace torchaudio {
namespace ffmpeg {

class FilterGraph {
AVMediaType media_type = AVMEDIA_TYPE_UNKNOWN;
// Parameters required for `reset`
// Recreats the underlying filter_graph struct
AVRational input_time_base;
AVCodecParameters* codecpar;
std::string filter_description;

// Constant just for convenient access.
AVMediaType media_type;

AVFilterGraphPtr pFilterGraph;
// AVFilterContext is freed as a part of AVFilterGraph
// so we do not manage the resource.
AVFilterContext* buffersrc_ctx = nullptr;
AVFilterContext* buffersink_ctx = nullptr;
const std::string filter_description;

public:
FilterGraph(
Expand All @@ -35,8 +42,12 @@ class FilterGraph {
//////////////////////////////////////////////////////////////////////////////
// Configuration methods
//////////////////////////////////////////////////////////////////////////////
void init();

void reset();

private:
void add_src(AVRational time_base, AVCodecParameters* codecpar);
void add_src();

void add_sink();

Expand Down
6 changes: 6 additions & 0 deletions torchaudio/csrc/ffmpeg/sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,11 @@ int Sink::process_frame(AVFrame* pFrame) {
bool Sink::is_buffer_ready() const {
return buffer->is_ready();
}

void Sink::flush() {
filter.reset();
buffer->flush();
}

} // namespace ffmpeg
} // namespace torchaudio
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class Sink {

int process_frame(AVFrame* frame);
bool is_buffer_ready() const;

void flush();
};

} // namespace ffmpeg
Expand Down
7 changes: 7 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ int StreamProcessor::process_packet(AVPacket* packet) {
return ret;
}

void StreamProcessor::flush() {
decoder.flush_buffer();
for (auto& ite : sinks) {
ite.second.flush();
}
}

// 0: some kind of success
// <0: Some error happened
int StreamProcessor::send_frame(AVFrame* pFrame) {
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class StreamProcessor {
// - Sending NULL will drain (flush) the internal
int process_packet(AVPacket* packet);

// flush the internal buffer of decoder.
// To be use when seeking
void flush();

private:
int send_frame(AVFrame* pFrame);

Expand Down
9 changes: 9 additions & 0 deletions torchaudio/csrc/ffmpeg/streamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,20 @@ bool Streamer::is_buffer_ready() const {
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void Streamer::seek(double timestamp) {
if (timestamp < 0) {
throw std::invalid_argument("timestamp must be non-negative.");
}

int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
if (ret < 0) {
throw std::runtime_error("Failed to seek. (" + av_err2string(ret) + ".)");
}
for (const auto& it : processors) {
if (it) {
it->flush();
}
}
}

void Streamer::add_audio_stream(
Expand Down