Skip to content

Commit

Permalink
Move FileObj to dedicated source
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed May 31, 2022
1 parent b56f60b commit b1d1171
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 78 deletions.
1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
if(USE_FFMPEG)
set(
FFMPEG_EXTENSION_SOURCES
ffmpeg/pybind/typedefs.cpp
ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp
)
Expand Down
70 changes: 0 additions & 70 deletions torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp
Original file line number Diff line number Diff line change
@@ -1,77 +1,7 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>

namespace torchaudio {
namespace ffmpeg {
namespace {

static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);

int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += chunk_len;
}
return num_read == 0 ? AVERROR_EOF : num_read;
}

static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}

AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}

// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);

if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace

FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}

StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
Expand Down
9 changes: 1 addition & 8 deletions torchaudio/csrc/ffmpeg/pybind/stream_reader.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>

namespace torchaudio {
namespace ffmpeg {

struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};

// The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat.
Expand Down
76 changes: 76 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/typedefs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>

namespace torchaudio {
namespace ffmpeg {
namespace {

static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);

int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += chunk_len;
}
return num_read == 0 ? AVERROR_EOF : num_read;
}

static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}

AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}

// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);

if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace

FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}

} // namespace ffmpeg
} // namespace torchaudio
16 changes: 16 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/typedefs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>

namespace torchaudio {
namespace ffmpeg {

struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};

} // namespace ffmpeg
} // namespace torchaudio

0 comments on commit b1d1171

Please sign in to comment.