Skip to content

Commit

Permalink
Merge pull request #9182 from JiayiFeng/dev_MultipleReader
Browse files Browse the repository at this point in the history
Multi-threaded reader in C++
  • Loading branch information
JiayiFeng authored Mar 21, 2018
2 parents e4bd63d + 2532b92 commit 7c041e4
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 4 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/reader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE)
endfunction()

reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};

void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}

if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}

*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
Expand Down
212 changes: 212 additions & 0 deletions paddle/fluid/operators/reader/open_files_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class MultipleReader : public framework::ReaderBase {
public:
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}

void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;

~MultipleReader() { EndScheduler(); }

private:
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);

std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};

void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}

if (local_buffer_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_;
local_buffer_.clear();
}

bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
}

void MultipleReader::ReInit() {
EndScheduler();
local_buffer_.clear();
StartNewScheduler();
}

void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);

for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
}

scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}

void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
}

void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}

void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
if (!available_thread_idx_->Send(&thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}

class OpenFilesOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;

private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");

auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
}
};

class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);

AddComment(R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
};

} // namespace reader
} // namespace operators
} // namespace paddle

namespace reader = paddle::operators::reader;

REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
reader::OpenFilesOpMaker);
15 changes: 15 additions & 0 deletions paddle/fluid/operators/reader/reader_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs;
}

std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_name].[file_format] (e.g., 'data_file.recordio').");
std::string filetype = file_name.substr(separator_pos + 1);

auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}

FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
Expand Down
9 changes: 7 additions & 2 deletions paddle/fluid/operators/reader/reader_op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace paddle {
namespace operators {
namespace reader {

static constexpr char kFileFormatSeparator[] = ".";

using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;

Expand All @@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dims);
};
return 0;
}

std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);

extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);

Expand Down
33 changes: 32 additions & 1 deletion python/paddle/fluid/layers/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader'
]


Expand Down Expand Up @@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)


def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []

for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))

var_name = unique_name('multiple_reader')

startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
})

startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)


def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
Loading

0 comments on commit 7c041e4

Please sign in to comment.