Skip to content

Commit

Permalink
Fix SRL hang on exit. (#291)
Browse files Browse the repository at this point in the history
* Fix SRL hang on exit.

* An error occurred when Async Load was enabled in TestDataProvider.
  * This happens because getNextBatchInternal is called in one thread while the DataProvider is being destructed in another thread.
  * Add a wait routine to the DataProvider destruction path.
* Also fix another bug that occurs when TestDataProvider is destructed without any test data having been read.

Fix #286

* Follow review comments; using a mutex is cool!
  • Loading branch information
reyoung authored and hedaoyuan committed Nov 7, 2016
1 parent c64cd6f commit e05f4ff
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
10 changes: 10 additions & 0 deletions demo/semantic_role_labeling/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
*.pyc
train.log
data/feature
data/conll05st-release/
data/src.dict
data/test.wsj.props
data/test.wsj.seq_pair
data/test.wsj.words
data/tgt.dict
output
3 changes: 2 additions & 1 deletion paddle/gserver/dataproviders/DataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,10 @@ void DoubleBuffer::asyncLoadBatch() {
taskReadySem_.wait();
if (stopping_) break;

while (batchSize_ == 0) {
while (batchSize_ == 0 && !stopping_) {
usleep(5);
}
if (stopping_) break;

do {
DataBatch newBatch;
Expand Down
19 changes: 17 additions & 2 deletions paddle/gserver/dataproviders/PyDataProvider2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,26 +433,34 @@ class PyDataProvider2 : public DataProvider {

/**
 * Reset the provider's loading state.
 *
 * Raises exit_, joins and releases the current load thread if one exists,
 * clears callingContexts_ and dataPool_ (each while a PyGuard is held),
 * zeroes poolActualSize_ and lowers exit_ again. mutexForReset_ is held
 * from the lock_guard below until the function returns;
 * getNextBatchInternal locks the same mutex, so the two cannot overlap
 * past that point.
 *
 * NOTE(review): this listing looks like a diff view in which removed and
 * added lines are shown together without +/- markers, so some statements
 * (for example the repeated exit_ assignments) may not all belong to the
 * same revision -- confirm against the real file.
 *
 * @param startNewThread when true, and cache_->reset() also returns true,
 *        a new thread running loadThread() is started and this function
 *        waits on callingContextCreated_ before returning.
 */
inline void resetImpl(bool startNewThread) {
DBG << "Reseting " << startNewThread;
// Raise the exit flag before touching the load thread.
exit_.store(true);
if (loadThread_) { // is loading.
exit_.store(true);
// Block until the load thread has finished, then drop the thread object.
loadThread_->join();
loadThread_.reset();
}
{
// PyGuard is held for the duration of this scope
// (presumably it guards access to Python objects -- TODO confirm).
PyGuard g;
callingContexts_.clear();
// Wake one thread waiting on pullCV_, if any.
this->pullCV_.notify_one();
}

// Held until the end of the function; getNextBatchInternal takes the
// same mutex.
std::lock_guard<std::mutex> guard(mutexForReset_);
{
PyGuard g;
dataPool_.clear();
}
poolActualSize_ = 0;
exit_ = false;

// Only start loading again when requested and cache_->reset() returns true.
if (startNewThread && cache_->reset()) {
DBG << "Start new thread.";
loadThread_.reset(new std::thread([this] {
exit_ = false;
loadThread();
}));
// Presumably blocks until the new thread has reached the same
// ThreadBarrier -- TODO confirm against ThreadBarrier's contract.
callingContextCreated_.wait();
}
DBG << "Reset done";
exit_ = false;
}

private:
Expand All @@ -465,6 +473,8 @@ class PyDataProvider2 : public DataProvider {
std::condition_variable pullCV_;
std::mutex mtx_;

std::mutex mutexForReset_;

ThreadBarrier callingContextCreated_;
std::unique_ptr<IPyDataProviderCache> cache_;

Expand Down Expand Up @@ -529,6 +539,7 @@ class PyDataProvider2 : public DataProvider {
* Loading a batch of data.
*/
int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) {
std::lock_guard<std::mutex> guard(mutexForReset_);
REGISTER_TIMER("PyDP2.getNextBatchInternal")
CHECK_GE(size_, 0);
size_t size = (size_t) size_;
Expand All @@ -554,6 +565,10 @@ class PyDataProvider2 : public DataProvider {
} else { // loading from cache.
poolPtr = this->cache_->load();
}
if (exit_) {
// PyDataProvider is destructing.
return 0;
}
CHECK(poolPtr != nullptr);

std::deque<PyObjectPtr>& pool = *poolPtr;
Expand Down
17 changes: 17 additions & 0 deletions paddle/gserver/tests/test_PyDataProvider2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,23 @@ TEST(PyDataProvider2, test_check) {
}
}

// Regression test for resetting and destroying an asynchronously loading
// py2 data provider: reset, pull one batch, reset again, then destroy.
// The test makes no assertions; it passes if the sequence completes.
TEST(PyDataProvider2, multiThread) {
  // Describe a "py2" provider with asynchronous loading switched on.
  paddle::DataConfig asyncConf;
  asyncConf.set_type("py2");
  asyncConf.set_files(FLAGS_train_list.c_str());
  asyncConf.set_load_data_module("test_PyDataProvider2");
  asyncConf.set_load_data_object("test_dense_no_seq");
  asyncConf.set_async_load_data(true);

  std::unique_ptr<paddle::DataProvider> dp(
      paddle::DataProvider::create(asyncConf, false));

  // First reset, then request a single batch.
  dp->reset();
  paddle::DataBatch fetched;
  dp->getNextBatch(100, &fetched);

  // Reset once more, then destroy the provider explicitly.
  dp->reset();
  dp.reset();
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
paddle::initMain(argc, argv);
Expand Down

0 comments on commit e05f4ff

Please sign in to comment.