From 550622529cf09ae4cb11c46817d73c9a1a5c88a2 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 20:04:03 +0800 Subject: [PATCH 1/6] Add MultipleReader and open_files_op --- paddle/fluid/operators/reader/CMakeLists.txt | 1 + .../reader/create_double_buffer_reader_op.cc | 5 +- .../fluid/operators/reader/open_files_op.cc | 199 ++++++++++++++++++ .../operators/reader/reader_op_registry.h | 22 +- 4 files changed, 224 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/reader/open_files_op.cc diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 744bd3b7ef71f8..1254783d69a87b 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -20,5 +20,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index d0de092947eb04..447fae10535c1b 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -120,10 +120,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { }; void DoubleBufferReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + if (local_buffer_.payloads_.empty()) { buffer_->Receive(&local_buffer_); } - *out = local_buffer_.payloads_; local_buffer_.payloads_.clear(); if (local_buffer_.ctx_) { diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc new file mode 100644 index 00000000000000..473c002e93a6db --- /dev/null +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -0,0 +1,199 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
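The hunk above makes DoubleBufferReader::ReadNext() fail loudly when the reader is exhausted instead of handing back stale data, and the new open_files_op.cc that follows gives MultipleReader the same HasNext()/ReadNext()/ReInit() contract. The sketch below only illustrates that consumer contract; ReaderBase is reduced to a toy interface here, and VectorReader plus the driver in main() are hypothetical stand-ins, not part of the patch.

```cpp
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

// Toy stand-in for framework::ReaderBase: the same three-method contract the
// patch relies on (guard every ReadNext() with HasNext(); ReInit() rewinds).
class ReaderBase {
 public:
  virtual void ReadNext(std::vector<int>* out) = 0;
  virtual bool HasNext() const = 0;
  virtual void ReInit() = 0;
  virtual ~ReaderBase() = default;
};

// Hypothetical in-memory reader used only to exercise the contract.
class VectorReader : public ReaderBase {
 public:
  explicit VectorReader(std::vector<std::vector<int>> data)
      : data_(std::move(data)) {}

  void ReadNext(std::vector<int>* out) override {
    if (!HasNext()) {
      // Mirrors the PADDLE_THROW("There is no next data!") added above.
      throw std::runtime_error("There is no next data!");
    }
    *out = data_[pos_++];
  }
  bool HasNext() const override { return pos_ < data_.size(); }
  void ReInit() override { pos_ = 0; }

 private:
  std::vector<std::vector<int>> data_;
  std::size_t pos_ = 0;
};

int main() {
  VectorReader reader({{1, 2}, {3, 4}, {5, 6}});
  for (int pass = 0; pass < 2; ++pass) {
    while (reader.HasNext()) {  // always check HasNext() before ReadNext()
      std::vector<int> batch;
      reader.ReadNext(&batch);
      std::cout << "pass " << pass << ": batch of " << batch.size() << "\n";
    }
    reader.ReInit();  // rewind and read the data again
  }
  return 0;
}
```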
+ +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class MultipleReader : public framework::ReaderBase { + public: + struct Quota {}; + + MultipleReader(const std::vector& file_names, + const std::vector& dims, size_t thread_num) + : file_names_(file_names), dims_(dims), thread_num_(thread_num) { + PADDLE_ENFORCE_GT(thread_num_, 0); + StartNewScheduler(); + } + + void ReadNext(std::vector* out) override; + bool HasNext() const override; + void ReInit() override; + + private: + void StartNewScheduler(); + void ScheduleThreadFunc(); + void PrefetchThreadFunc(std::string file_name); + + std::vector file_names_; + std::vector dims_; + size_t thread_num_; + framework::Channel* waiting_file_idx_; + framework::Channel* thread_quotas_; + framework::Channel>* buffer_; + mutable std::vector local_buffer_; +}; + +void MultipleReader::ReadNext(std::vector* out) { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + + if (local_buffer_.empty()) { + buffer_->Receive(&local_buffer_); + } + *out = local_buffer_; + local_buffer_.clear(); +} + +bool MultipleReader::HasNext() const { + return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true; +} + +void MultipleReader::ReInit() { + buffer_->Close(); + thread_quotas_->Close(); + waiting_file_idx_->Close(); + local_buffer_.clear(); + + StartNewScheduler(); +} + +void MultipleReader::StartNewScheduler() { + waiting_file_idx_ = framework::MakeChannel(file_names_.size()); + thread_quotas_ = framework::MakeChannel(thread_num_); + buffer_ = + framework::MakeChannel>(thread_num_); + + for (size_t i = 0; i < file_names_.size(); ++i) { + waiting_file_idx_->Send(&i); + } + waiting_file_idx_->Close(); + for (size_t i = 0; i < thread_num_; ++i) { + Quota quota; + thread_quotas_->Send("a); + } + + std::thread scheduler([this] { ScheduleThreadFunc(); }); + scheduler.detach(); +} + +void MultipleReader::ScheduleThreadFunc() { + VLOG(5) << "MultipleReader schedule thread starts."; + size_t completed_thread_num = 0; + Quota quota; + while (thread_quotas_->Receive("a)) { + size_t file_idx; + if (waiting_file_idx_->Receive(&file_idx)) { + // Still have files to read. Start a new prefetch thread. + std::string file_name = file_names_[file_idx]; + std::thread prefetcher( + [this, file_name] { PrefetchThreadFunc(file_name); }); + prefetcher.detach(); + } else { + // No more file to read. + ++completed_thread_num; + if (completed_thread_num == thread_num_) { + thread_quotas_->Close(); + buffer_->Close(); + break; + } + } + } + VLOG(5) << "MultipleReader schedule thread terminates."; +} + +void MultipleReader::PrefetchThreadFunc(std::string file_name) { + VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; + std::unique_ptr reader = + CreateReaderByFileName(file_name, dims_); + while (reader->HasNext()) { + std::vector ins; + reader->ReadNext(&ins); + if (!buffer_->Send(&ins)) { + VLOG(5) << "WARNING: The buffer channel has been closed. 
The prefetch " + "thread of file '" + << file_name << "' will terminate."; + break; + } + } + Quota quota; + thread_quotas_->Send("a); + VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; +} + +class OpenFilesOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + const auto& file_names = Attr>("file_names"); + PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); + const size_t thread_num = Attr("thread_num"); + + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new MultipleReader( + file_names, RestoreShapes(shape_concat, ranks), thread_num)); + } +}; + +class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddComment(R"DOC( + OpenFiles Operator + + An OpenFilesOp creates a MultipleReader, which is able to + read data multi-threaded from multiple files. + )DOC"); + AddOutput("Out", "(ReaderHolder) The created MultipleReader."); + AddAttr>("shape_concat", + "The concat of all data's shapes."); + AddAttr>( + "ranks", + "The ranks of each data." + "e.g." + "shape_concat = [2,3,4,5,6]" + "ranks = [3,2]" + "It means the reader will generate two data each time," + "whose shapes are [2,3,4] and [5,6] respectively."); + AddAttr>("lod_levels", "The LoD levels of each data."); + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace reader = paddle::operators::reader; + +REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, + reader::OpenFilesOpMaker); \ No newline at end of file diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index 58f9b4ba355465..feab7c63a3eeea 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,6 +21,8 @@ namespace paddle { namespace operators { namespace reader { +static constexpr char kFileFormatSeparator[] = ":"; + using FileReaderCreator = std::function&)>; @@ -29,12 +31,28 @@ std::unordered_map& FileReaderRegistry(); template int RegisterFileReader(const std::string& filetype) { FileReaderRegistry()[filetype] = []( - const std::string& fn, const std::vector& dim) { - return new Reader(fn, dim); + const std::string& fn, const std::vector& dims) { + return new Reader(fn, dims); }; return 0; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! 
A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); From 3d677b1eca75733adbc1939dd0a50cbacead6718 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 20:29:48 +0800 Subject: [PATCH 2/6] fix compile errors and make OpenFilesOpMaker derived from FileReaderMakerBase --- paddle/fluid/operators/reader/CMakeLists.txt | 2 +- .../fluid/operators/reader/open_files_op.cc | 25 ++++++------------- .../operators/reader/reader_op_registry.cc | 16 ++++++++++++ .../operators/reader/reader_op_registry.h | 15 +---------- 4 files changed, 25 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 1254783d69a87b..4a43fc02d2189e 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -15,11 +15,11 @@ function(reader_library TARGET_NAME) PARENT_SCOPE) endfunction() +reader_library(open_files_op SRCS open_files_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) -reader_library(open_files_op SRCS open_files_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 473c002e93a6db..6b62e1db490760 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -161,31 +161,20 @@ class OpenFilesOp : public framework::OperatorBase { } }; -class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { +class OpenFilesOpMaker : public FileReaderMakerBase { public: OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { + : FileReaderMakerBase(op_proto, op_checker) { + AddAttr>("file_names", "Files to be read."); + AddAttr("thread_num", "The maximal concurrent prefetch thread number.") + .GreaterThan(0); + AddComment(R"DOC( OpenFiles Operator An OpenFilesOp creates a MultipleReader, which is able to read data multi-threaded from multiple files. )DOC"); - AddOutput("Out", "(ReaderHolder) The created MultipleReader."); - AddAttr>("shape_concat", - "The concat of all data's shapes."); - AddAttr>( - "ranks", - "The ranks of each data." - "e.g." 
- "shape_concat = [2,3,4,5,6]" - "ranks = [3,2]" - "It means the reader will generate two data each time," - "whose shapes are [2,3,4] and [5,6] respectively."); - AddAttr>("lod_levels", "The LoD levels of each data."); - AddAttr>("file_names", "Files to be read."); - AddAttr("thread_num", "The maximal concurrent prefetch thread number.") - .GreaterThan(0); } }; @@ -196,4 +185,4 @@ class OpenFilesOpMaker : public framework::OpProtoAndCheckerMaker { namespace reader = paddle::operators::reader; REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp, - reader::OpenFilesOpMaker); \ No newline at end of file + reader::OpenFilesOpMaker); diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 0ba4f385443174..05d79c76d5ab0e 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -36,6 +36,22 @@ std::unordered_map& FileReaderRegistry() { return regs; } +std::unique_ptr CreateReaderByFileName( + const std::string& file_name, const std::vector& dims) { + size_t separator_pos = file_name.find(kFileFormatSeparator); + PADDLE_ENFORCE_NE(separator_pos, std::string::npos, + "File name illegal! A legal file name should be like: " + "[file_format]:[file_name] (e.g., 'recordio:data_file')."); + std::string filetype = file_name.substr(0, separator_pos); + std::string f_name = file_name.substr(separator_pos + 1); + + auto itor = FileReaderRegistry().find(filetype); + PADDLE_ENFORCE(itor != FileReaderRegistry().end(), + "No file reader registered for '%s' format.", filetype); + framework::ReaderBase* reader = (itor->second)(f_name, dims); + return std::unique_ptr(reader); +} + FileReaderMakerBase::FileReaderMakerBase( framework::OpProtoAndCheckerMaker::OpProto* op_proto, framework::OpAttrChecker* op_checker) diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index feab7c63a3eeea..dd19b982dad862 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -38,20 +38,7 @@ int RegisterFileReader(const std::string& filetype) { } std::unique_ptr CreateReaderByFileName( - const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); - PADDLE_ENFORCE_NE(separator_pos, std::string::npos, - "File name illegal! 
A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); - - auto itor = FileReaderRegistry().find(filetype); - PADDLE_ENFORCE(itor != FileReaderRegistry().end(), - "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); - return std::unique_ptr(reader); -} + const std::string& file_name, const std::vector& dims); extern std::vector RestoreShapes( const std::vector& shape_concat, const std::vector& ranks); From 87ac675ae7365cdc8afc8f12503df962ce9aaabc Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 18 Mar 2018 23:49:11 +0800 Subject: [PATCH 3/6] Add python wrapper for open_files_op --- python/paddle/fluid/layers/io.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 9c91f395e7c9d7..89153f325bed5b 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -287,6 +287,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var) +def open_files(filenames, thread_num, shapes, lod_levels, dtypes): + dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] + shape_concat = [] + ranks = [] + + for shape in shapes: + shape_concat.extend(shape) + ranks.append(len(shape)) + + var_name = unique_name('multiple_reader') + + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=var_name) + startup_blk.append_op( + type='open_files', + outputs={'Out': [startup_var]}, + attrs={ + 'shape_concat': shape_concat, + 'lod_levels': lod_levels, + 'ranks': ranks, + 'filename': filenames, + 'thread_num': thread_num + }) + + startup_var.desc.set_dtypes(dtypes) + startup_var.persistable = True + return _copy_reader_var_(default_main_program().current_block(), + startup_var) + + def __create_decorated_reader__(op_type, reader, attrs): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() From a2981f5c5018c23aa969389c64a329e53f8cf290 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 15:16:42 +0800 Subject: [PATCH 4/6] fix a bug --- .../fluid/operators/reader/open_files_op.cc | 79 ++++++++++++------- 1 file changed, 50 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 6b62e1db490760..49cdf5365c9964 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -21,12 +21,10 @@ namespace reader { class MultipleReader : public framework::ReaderBase { public: - struct Quota {}; - MultipleReader(const std::vector& file_names, const std::vector& dims, size_t thread_num) - : file_names_(file_names), dims_(dims), thread_num_(thread_num) { - PADDLE_ENFORCE_GT(thread_num_, 0); + : file_names_(file_names), dims_(dims) { + prefetchers_.resize(thread_num); StartNewScheduler(); } @@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase { bool HasNext() const override; void ReInit() override; + ~MultipleReader() { EndScheduler(); } + private: void StartNewScheduler(); + void EndScheduler(); void ScheduleThreadFunc(); - void PrefetchThreadFunc(std::string file_name); + void PrefetchThreadFunc(std::string file_name, size_t thread_idx); std::vector file_names_; std::vector dims_; - size_t thread_num_; 
+ std::thread scheduler_; + std::vector prefetchers_; framework::Channel* waiting_file_idx_; - framework::Channel* thread_quotas_; + framework::Channel* available_thread_idx_; framework::Channel>* buffer_; mutable std::vector local_buffer_; }; @@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const { } void MultipleReader::ReInit() { - buffer_->Close(); - thread_quotas_->Close(); - waiting_file_idx_->Close(); + EndScheduler(); local_buffer_.clear(); - StartNewScheduler(); } void MultipleReader::StartNewScheduler() { + size_t thread_num = prefetchers_.size(); waiting_file_idx_ = framework::MakeChannel(file_names_.size()); - thread_quotas_ = framework::MakeChannel(thread_num_); + available_thread_idx_ = framework::MakeChannel(thread_num); buffer_ = - framework::MakeChannel>(thread_num_); + framework::MakeChannel>(thread_num); for (size_t i = 0; i < file_names_.size(); ++i) { waiting_file_idx_->Send(&i); } waiting_file_idx_->Close(); - for (size_t i = 0; i < thread_num_; ++i) { - Quota quota; - thread_quotas_->Send("a); + for (size_t i = 0; i < thread_num; ++i) { + available_thread_idx_->Send(&i); } - std::thread scheduler([this] { ScheduleThreadFunc(); }); - scheduler.detach(); + scheduler_ = std::thread([this] { ScheduleThreadFunc(); }); +} + +void MultipleReader::EndScheduler() { + available_thread_idx_->Close(); + buffer_->Close(); + waiting_file_idx_->Close(); + scheduler_.join(); + delete buffer_; + delete available_thread_idx_; + delete waiting_file_idx_; } void MultipleReader::ScheduleThreadFunc() { VLOG(5) << "MultipleReader schedule thread starts."; size_t completed_thread_num = 0; - Quota quota; - while (thread_quotas_->Receive("a)) { + size_t thread_idx; + while (available_thread_idx_->Receive(&thread_idx)) { + std::thread& prefetcher = prefetchers_[thread_idx]; + if (prefetcher.joinable()) { + prefetcher.join(); + } size_t file_idx; if (waiting_file_idx_->Receive(&file_idx)) { // Still have files to read. Start a new prefetch thread. std::string file_name = file_names_[file_idx]; - std::thread prefetcher( - [this, file_name] { PrefetchThreadFunc(file_name); }); - prefetcher.detach(); + prefetcher = std::thread([this, file_name, thread_idx] { + PrefetchThreadFunc(file_name, thread_idx); + }); } else { // No more file to read. ++completed_thread_num; - if (completed_thread_num == thread_num_) { - thread_quotas_->Close(); - buffer_->Close(); + if (completed_thread_num == prefetchers_.size()) { break; } } } + // If users invoke ReInit() when scheduler is running, it will close the + // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler + // to release their resource. So a check is needed before scheduler ends. + for (auto& p : prefetchers_) { + if (p.joinable()) { + p.join(); + } + } VLOG(5) << "MultipleReader schedule thread terminates."; } -void MultipleReader::PrefetchThreadFunc(std::string file_name) { +void MultipleReader::PrefetchThreadFunc(std::string file_name, + size_t thread_idx) { VLOG(5) << "The prefetch thread of file '" << file_name << "' starts."; std::unique_ptr reader = CreateReaderByFileName(file_name, dims_); @@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) { break; } } - Quota quota; - thread_quotas_->Send("a); + if (!available_thread_idx_->Send(&thread_idx)) { + VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. 
" + "Fail to send thread_idx."; + } VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates."; } From f863866471f285015201183994d45dc5637919bb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 17:54:37 +0800 Subject: [PATCH 5/6] Add an unitest --- .../fluid/operators/reader/open_files_op.cc | 4 +- .../operators/reader/reader_op_registry.cc | 9 ++- .../operators/reader/reader_op_registry.h | 2 +- python/paddle/fluid/layers/io.py | 5 +- .../tests/unittests/test_multiple_reader.py | 71 +++++++++++++++++++ 5 files changed, 82 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_multiple_reader.py diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 49cdf5365c9964..1ab4111efe80f5 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -94,7 +94,9 @@ void MultipleReader::EndScheduler() { available_thread_idx_->Close(); buffer_->Close(); waiting_file_idx_->Close(); - scheduler_.join(); + if (scheduler_.joinable()) { + scheduler_.join(); + } delete buffer_; delete available_thread_idx_; delete waiting_file_idx_; diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 05d79c76d5ab0e..fc8dc747ff0c22 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -38,17 +38,16 @@ std::unordered_map& FileReaderRegistry() { std::unique_ptr CreateReaderByFileName( const std::string& file_name, const std::vector& dims) { - size_t separator_pos = file_name.find(kFileFormatSeparator); + size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); PADDLE_ENFORCE_NE(separator_pos, std::string::npos, "File name illegal! 
A legal file name should be like: " - "[file_format]:[file_name] (e.g., 'recordio:data_file')."); - std::string filetype = file_name.substr(0, separator_pos); - std::string f_name = file_name.substr(separator_pos + 1); + "[file_name].[file_format] (e.g., 'data_file.recordio')."); + std::string filetype = file_name.substr(separator_pos + 1); auto itor = FileReaderRegistry().find(filetype); PADDLE_ENFORCE(itor != FileReaderRegistry().end(), "No file reader registered for '%s' format.", filetype); - framework::ReaderBase* reader = (itor->second)(f_name, dims); + framework::ReaderBase* reader = (itor->second)(file_name, dims); return std::unique_ptr(reader); } diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h index dd19b982dad862..929d32ad8b3678 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.h +++ b/paddle/fluid/operators/reader/reader_op_registry.h @@ -21,7 +21,7 @@ namespace paddle { namespace operators { namespace reader { -static constexpr char kFileFormatSeparator[] = ":"; +static constexpr char kFileFormatSeparator[] = "."; using FileReaderCreator = std::function&)>; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 89153f325bed5b..f169642eaa44ea 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,8 @@ __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' + 'open_files', 'read_file', 'create_shuffle_reader', + 'create_double_buffer_reader' ] @@ -307,7 +308,7 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): 'shape_concat': shape_concat, 'lod_levels': lod_levels, 'ranks': ranks, - 'filename': filenames, + 'file_names': filenames, 'thread_num': thread_num }) diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py new file mode 100644 index 00000000000000..cb1aaaae5a7a45 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
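Patch 5 above also changes the file-name convention from '[file_format]:[file_name]' to '[file_name].[file_format]', so the test that follows can simply name its files mnist_0.recordio and let the suffix select the reader. A minimal, self-contained sketch of that suffix dispatch is shown below; the Registry() contents and the behaviour of the registered factory are made up for illustration and are not Paddle's real FileReaderRegistry.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Illustrative stand-in for FileReaderRegistry(): maps a file-format suffix
// to a factory. The "recordio" entry here is a placeholder, not Paddle's.
using Factory = std::function<void(const std::string&)>;

std::unordered_map<std::string, Factory>& Registry() {
  static std::unordered_map<std::string, Factory> registry = {
      {"recordio", [](const std::string& fn) {
         std::cout << "would open RecordIO file: " << fn << "\n";
       }}};
  return registry;
}

// Mirrors CreateReaderByFileName() after patch 5: the format is whatever
// follows the last '.', and the full name is passed on to the factory.
void CreateReaderByFileName(const std::string& file_name) {
  std::size_t separator_pos = file_name.find_last_of('.');
  if (separator_pos == std::string::npos) {
    throw std::runtime_error(
        "File name illegal! Expected [file_name].[file_format], "
        "e.g. 'data_file.recordio'.");
  }
  std::string filetype = file_name.substr(separator_pos + 1);
  auto it = Registry().find(filetype);
  if (it == Registry().end()) {
    throw std::runtime_error("No file reader registered for '" + filetype +
                             "' format.");
  }
  it->second(file_name);
}

int main() {
  CreateReaderByFileName("mnist_0.recordio");  // dispatches on ".recordio"
  return 0;
}
```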
+ +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist +from shutil import copyfile + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch(mnist.train(), batch_size=32) + feeder = fluid.DataFeeder( + feed_list=[ # order is image and label + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist_0.recordio', reader, feeder) + copyfile('./mnist_0.recordio', './mnist_1.recordio') + copyfile('./mnist_0.recordio', './mnist_2.recordio') + print(self.num_batch) + + def test_multiple_reader(self, thread_num=3): + file_list = [ + './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' + ] + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_files = fluid.layers.open_files( + filenames=file_list, + thread_num=thread_num, + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + img, label = fluid.layers.read_file(data_files) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_files.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + print(batch_count) + # data_files.reset() + print("FUCK") + + self.assertEqual(batch_count, self.num_batch * 3) From 2532b922dc4897478589d7b4064cde40113f943b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 19:20:58 +0800 Subject: [PATCH 6/6] Add more unittests and fix bugs --- paddle/fluid/operators/reader/open_files_op.cc | 1 + python/paddle/fluid/tests/unittests/.gitignore | 3 +++ .../tests/unittests/test_multiple_reader.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 1ab4111efe80f5..414c76fea0bb91 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -122,6 +122,7 @@ void MultipleReader::ScheduleThreadFunc() { // No more file to read. 
++completed_thread_num; if (completed_thread_num == prefetchers_.size()) { + buffer_->Close(); break; } } diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 6b3fc2a83c649c..ad02bdecf436bb 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -1 +1,4 @@ mnist.recordio +mnist_0.recordio +mnist_1.recordio +mnist_2.recordio diff --git a/python/paddle/fluid/tests/unittests/test_multiple_reader.py b/python/paddle/fluid/tests/unittests/test_multiple_reader.py index cb1aaaae5a7a45..69f8acf81efaba 100644 --- a/python/paddle/fluid/tests/unittests/test_multiple_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multiple_reader.py @@ -22,9 +22,10 @@ class TestMultipleReader(unittest.TestCase): def setUp(self): + self.batch_size = 64 # Convert mnist to recordio file with fluid.program_guard(fluid.Program(), fluid.Program()): - reader = paddle.batch(mnist.train(), batch_size=32) + reader = paddle.batch(mnist.train(), batch_size=self.batch_size) feeder = fluid.DataFeeder( feed_list=[ # order is image and label fluid.layers.data( @@ -37,9 +38,8 @@ def setUp(self): './mnist_0.recordio', reader, feeder) copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio') - print(self.num_batch) - def test_multiple_reader(self, thread_num=3): + def main(self, thread_num): file_list = [ './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' ] @@ -64,8 +64,11 @@ def test_multiple_reader(self, thread_num=3): while not data_files.eof(): img_val, = exe.run(fetch_list=[img]) batch_count += 1 - print(batch_count) - # data_files.reset() - print("FUCK") - + self.assertLessEqual(img_val.shape[0], self.batch_size) + data_files.reset() self.assertEqual(batch_count, self.num_batch * 3) + + def test_main(self): + self.main(thread_num=3) # thread number equals to file number + self.main(thread_num=10) # thread number is larger than file number + self.main(thread_num=2) # thread number is less than file number
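Across the series, MultipleReader ends up with a scheduler thread that hands out file indices and thread slots over framework::Channel objects, prefetch threads that push batches into a shared bounded buffer, and (after patch 6) a buffer_->Close() once the last prefetch thread finishes, so that the consumer can drain the remaining batches and stop. The self-contained sketch below illustrates that hand-off with a toy closable channel built from the standard library; it is an explanatory analogue under simplified assumptions, not Paddle's framework::Channel, and the "files" are plain integer vectors.

```cpp
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal closable bounded channel. Send() returns false once the channel is
// closed; Receive() returns false once it is closed *and* drained -- the
// semantics MultipleReader relies on.
template <typename T>
class Channel {
 public:
  explicit Channel(std::size_t cap) : cap_(cap) {}

  bool Send(T v) {
    std::unique_lock<std::mutex> lk(mu_);
    not_full_.wait(lk, [&] { return closed_ || q_.size() < cap_; });
    if (closed_) return false;
    q_.push(std::move(v));
    not_empty_.notify_one();
    return true;
  }

  bool Receive(T* out) {
    std::unique_lock<std::mutex> lk(mu_);
    not_empty_.wait(lk, [&] { return closed_ || !q_.empty(); });
    if (q_.empty()) return false;  // closed and drained
    *out = std::move(q_.front());
    q_.pop();
    not_full_.notify_one();
    return true;
  }

  void Close() {
    std::lock_guard<std::mutex> lk(mu_);
    closed_ = true;
    not_empty_.notify_all();
    not_full_.notify_all();
  }

 private:
  std::size_t cap_;
  bool closed_ = false;
  std::queue<T> q_;
  std::mutex mu_;
  std::condition_variable not_empty_, not_full_;
};

int main() {
  // Three hypothetical "files", each producing a few integer batches.
  std::vector<std::vector<int>> files = {{1, 2, 3}, {4, 5}, {6, 7, 8, 9}};
  Channel<int> buffer(2);  // bounded, like buffer_ in MultipleReader

  // One prefetcher per file; the last one to finish closes the buffer,
  // which is what patch 6's buffer_->Close() in the scheduler achieves.
  std::vector<std::thread> prefetchers;
  std::mutex done_mu;
  std::size_t done = 0;
  for (const auto& file : files) {
    prefetchers.emplace_back([&, file] {
      for (int batch : file) {
        if (!buffer.Send(batch)) break;  // buffer closed early (e.g. ReInit)
      }
      std::lock_guard<std::mutex> lk(done_mu);
      if (++done == files.size()) buffer.Close();
    });
  }

  // Consumer side: a ReadNext()-style loop that drains until Close().
  int batch;
  while (buffer.Receive(&batch)) {
    std::cout << "got batch " << batch << "\n";
  }

  for (auto& t : prefetchers) t.join();  // like EndScheduler() joining threads
  return 0;
}
```

Closing the channel, rather than pushing a sentinel value, wakes every blocked Receive() and Send() at once, which is also why ReInit() and EndScheduler() in the patch work by closing the three channels before tearing the threads down.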