-
Notifications
You must be signed in to change notification settings - Fork 18.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added DataReader for parallel training with one DB session
- Makes sure each solver accesses a different subset of the data
- Reads the DB sequentially, for performance
- Prefetches a configurable amount of data to host memory
- Distributes data to solvers in a round-robin way, for determinism
- Loading branch information
Showing
8 changed files
with
260 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#ifndef CAFFE_DATA_READER_HPP_ | ||
#define CAFFE_DATA_READER_HPP_ | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "caffe/common.hpp" | ||
#include "caffe/internal_thread.hpp" | ||
#include "caffe/util/blocking_queue.hpp" | ||
#include "caffe/util/db.hpp" | ||
|
||
namespace caffe { | ||
|
||
/** | ||
* @brief Reads data from a source to queues available to data layers. | ||
* A single reading thread is created per source, even if multiple solvers | ||
* are running in parallel, e.g. for multi-GPU training. This makes sure | ||
* databases are read sequentially, and that each solver accesses a different | ||
* subset of the database. Data is distributed to solvers in a round-robin | ||
* way to keep parallel training deterministic. | ||
*/ | ||
class DataReader {
 public:
  explicit DataReader(const LayerParameter& param);
  ~DataReader();

  // Queue of recycled Datum objects the data layer returns after use;
  // the reading thread refills them from the DB.
  inline BlockingQueue<Datum*>& free() const {
    return queues_->free_;
  }
  // Queue of Datum objects already filled from the DB, ready for consumption.
  inline BlockingQueue<Datum*>& full() const {
    return queues_->full_;
  }

 protected:
  // Queue pairs are shared between a body and its readers.
  // free_/full_ together form a bounded producer/consumer ring:
  // the total number of Datum objects is fixed at construction time.
  class QueuePair {
   public:
    explicit QueuePair(int size);
    ~QueuePair();

    BlockingQueue<Datum*> free_;
    BlockingQueue<Datum*> full_;

  DISABLE_COPY_AND_ASSIGN(QueuePair);
  };

  // A single body is created per source: it owns the DB-reading thread
  // and fans data out to every reader's queue pair.
  class Body : public InternalThread {
   public:
    explicit Body(const LayerParameter& param);
    virtual ~Body();

   protected:
    void InternalThreadEntry();
    // Reads one Datum from the cursor into the queue pair at `index`.
    void read_one(db::Cursor* cursor, int index);

    const LayerParameter param_;
    // One queue pair per solver sharing this source (round-robin order).
    vector<shared_ptr<QueuePair> > reader_queues_;

    friend class DataReader;

  DISABLE_COPY_AND_ASSIGN(Body);
  };

  // A source is uniquely identified by its layer name + path, in case
  // the same database is read from two different locations in the net.
  static inline string source_key(const LayerParameter& param) {
    return param.name() + ":" + param.data_param().source();
  }

  // This reader's own queue pair, registered with the shared body.
  const shared_ptr<QueuePair> queues_;
  // Shared reading body; kept alive as long as any reader references it.
  shared_ptr<Body> body_;

  // Registry of live bodies keyed by source_key(); weak_ptr so a body
  // is destroyed once its last reader goes away.
  static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;

DISABLE_COPY_AND_ASSIGN(DataReader);
};
|
||
} // namespace caffe | ||
|
||
#endif // CAFFE_DATA_READER_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
#include <boost/thread.hpp> | ||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "caffe/common.hpp" | ||
#include "caffe/data_layers.hpp" | ||
#include "caffe/data_reader.hpp" | ||
#include "caffe/proto/caffe.pb.h" | ||
|
||
namespace caffe { | ||
|
||
using boost::weak_ptr;

// Registry of live reading bodies, keyed by layer name + DB path.
map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
// Guards bodies_ (accessed from readers and from the reading threads).
static boost::mutex bodies_mutex_;

// TODO single solver until multi-gpu merge
static const int solver_count = 1;
|
||
DataReader::DataReader(const LayerParameter& param)
    : queues_(new QueuePair(
          param.data_param().prefetch() * param.data_param().batch_size())) {
  // Look up (or create) the single reading body for this source.
  boost::mutex::scoped_lock lock(bodies_mutex_);
  const string key = source_key(param);
  weak_ptr<Body>& slot = bodies_[key];
  body_ = slot.lock();
  if (!body_) {
    // First reader for this source: spawn the shared reading thread.
    body_.reset(new Body(param));
    slot = body_;
  }
  // Register this reader's queue pair; the body distributes to it
  // in round-robin order.
  body_->reader_queues_.push_back(queues_);
  // Check a single net is trained at a time per process, whether single
  // or multi solver. This might also happen if two data layers have same
  // name and same source.
  CHECK(body_->reader_queues_.size() <= solver_count);
}
|
||
DataReader::~DataReader() {
  // Capture the key before dropping our reference to the body.
  const string key = source_key(body_->param_);
  body_.reset();
  boost::mutex::scoped_lock lock(bodies_mutex_);
  // Use find() rather than operator[]: operator[] would default-insert an
  // empty weak_ptr just to check it, needlessly growing the registry when
  // the entry is already gone.
  map<const string, weak_ptr<Body> >::iterator it = bodies_.find(key);
  if (it != bodies_.end() && it->second.expired()) {
    // We were the last reader for this source; drop the registry entry.
    bodies_.erase(it);
  }
}
|
||
// | ||
|
||
DataReader::QueuePair::QueuePair(int size) {
  // Pre-populate the free queue with the requested number of Datum
  // objects; these are recycled between free_ and full_ thereafter.
  int remaining = size;
  while (remaining-- > 0) {
    free_.push(new Datum());
  }
}
|
||
DataReader::QueuePair::~QueuePair() {
  // Reclaim every Datum still parked in either queue.
  Datum* item;
  while (free_.try_pop(&item)) delete item;
  while (full_.try_pop(&item)) delete item;
}
|
||
// | ||
|
||
// Constructs the shared reading body and immediately launches its
// background thread, which opens the DB and starts prefetching.
DataReader::Body::Body(const LayerParameter& param)
    : param_(param), reader_queues_() {
  StartInternalThread();
}
|
||
// Signals the reading thread to stop and blocks until it has exited.
DataReader::Body::~Body() {
  StopInternalThread();
}
|
||
// Thread entry point: opens the DB and distributes datums to the
// registered reader queues in round-robin order until stopped.
void DataReader::Body::InternalThreadEntry() {
  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
  db->Open(param_.data_param().source(), db::READ);
  shared_ptr<db::Cursor> cursor(db->NewCursor());
  try {
    // Synchronize with main thread to make sure we see at least one queue
    // (the constructor pushes the first queue while holding bodies_mutex_,
    // so acquiring the lock here orders us after that push).
    {
      boost::mutex::scoped_lock lock(bodies_mutex_);
      CHECK_GE(reader_queues_.size(), 1);
    }
    // To ensure deterministic runs, only start running once all solvers
    // are ready. But solvers need to peek on one item during initialization,
    // so to allow the root solver to start before the other solvers are
    // created, read one item.
    int index = 0;
    if (param_.phase() == TRAIN) {
      read_one(cursor.get(), index++);

      // Wait on remaining solvers
      // (poll under the lock; queues are registered by DataReader ctors).
      while (!must_stop()) {
        usleep(100 * 1000);
        boost::mutex::scoped_lock lock(bodies_mutex_);
        if (reader_queues_.size() == solver_count) {
          break;
        }
      }
    }
    // Main loop: round-robin over the reader queues for determinism.
    while (!must_stop()) {
      if (index == reader_queues_.size()) {
        index = 0;
      }
      read_one(cursor.get(), index++);
    }
  } catch (boost::thread_interrupted&) {
    // Interrupted exception is expected on shutdown
  }
}
|
||
// Fills one recycled Datum from the cursor and hands it to the reader
// queue at `index`; blocks if that reader has no free Datum available.
void DataReader::Body::read_one(db::Cursor* cursor, int index) {
  QueuePair* pair = reader_queues_[index].get();
  Datum* datum = pair->free_.pop();
  // TODO deserialize in-place instead of copy?
  datum->ParseFromString(cursor->value());
  pair->full_.push(datum);

  // Advance to the next record, wrapping around at the end of the DB.
  cursor->Next();
  if (!cursor->valid()) {
    DLOG(INFO) << "Restarting data prefetching from start.";
    cursor->SeekToFirst();
  }
}
|
||
} // namespace caffe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.