Skip to content

Commit

Permalink
[GPUPS]Fix dataset (PaddlePaddle#36)
Browse files Browse the repository at this point in the history
* revert pipeline pull

* fix conflict

* fix conflict

* fix conflict

* add jvm.so

* Revert "pipeline build (#9)"

This reverts commit 869c43f.

* revert async build pull

* fix dataset

* fix dataset
  • Loading branch information
zmxdream authored Jul 6, 2022
1 parent 61be085 commit cdd3beb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 11 additions & 3 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,14 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
int remain = 0;
size_t begin = 0;

std::string data_set_name = std::string(typeid(*dataset_).name());
dataset_mutex_.lock();
Dataset* cur_dataset = dataset_pipe_.front();
dataset_pipe_.pop();
dataset_mutex_.unlock();
std::string data_set_name = std::string(typeid(*cur_dataset).name());

if (data_set_name.find("SlotRecordDataset") != std::string::npos) {
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(dataset_);
SlotRecordDataset* dataset = dynamic_cast<SlotRecordDataset*>(cur_dataset);
auto input_channel = dataset->GetInputChannel();
VLOG(3) << "buildtask::inputslotchannle size: "
<< input_channel->Size();
Expand Down Expand Up @@ -958,7 +962,11 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
InitSlotInfo();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();


dataset_mutex_.lock();
dataset_pipe_.push(dataset_);
dataset_mutex_.unlock();

data_ready_channel_->Put(gpu_task);

VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]";
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ class PSGPUWrapper {
private:
static std::shared_ptr<PSGPUWrapper> s_instance_;
Dataset* dataset_;

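// Once data loading completes, the dataset is pushed onto the queue below;
// the subsequent asynchronous pull/build step consumes datasets from this queue.
// Loading and asynchronous building run on two separate threads, so this queue
// (guarded by dataset_mutex_) decouples the dataset object between them.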
std::queue<Dataset*> dataset_pipe_;
std::mutex dataset_mutex_;

#ifdef PADDLE_WITH_PSLIB
paddle::ps::AfsApiWrapper afs_handler_;
#endif
Expand Down

0 comments on commit cdd3beb

Please sign in to comment.