Skip to content

Commit

Permalink
Merge pull request jack603047588#62 from qingshui/paddlebox
Browse files Browse the repository at this point in the history
Add HBM (high-bandwidth GPU memory) garbage collection
  • Loading branch information
qingshui authored Apr 25, 2023
2 parents 007dda6 + f24d2c3 commit e0b3ffe
Show file tree
Hide file tree
Showing 19 changed files with 607 additions and 337 deletions.
248 changes: 170 additions & 78 deletions paddle/fluid/framework/boxps_worker.cc

Large diffs are not rendered by default.

28 changes: 20 additions & 8 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2725,9 +2725,9 @@ void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) {

size_t* d_slot_offsets = reinterpret_cast<size_t*>(pack_->gpu_slot_offsets());

HostBuffer<size_t>& offsets = pack_->offsets();
auto& offsets = pack_->offsets();
offsets.resize(slot_total_num);
HostBuffer<void*>& h_tensor_ptrs = pack_->h_tensor_ptrs();
auto& h_tensor_ptrs = pack_->h_tensor_ptrs();
h_tensor_ptrs.resize(use_slot_size_);
// alloc gpu memory
pack_->resize_tensor();
Expand Down Expand Up @@ -3120,6 +3120,7 @@ void SlotPaddleBoxDataFeed::GetUsedSlotIndex(
// get feasigns that FeedPass doesn't need
const std::unordered_set<std::string>& slot_name_omited_in_feedpass_ =
boxps_ptr->GetOmitedSlot();
const std::vector<int>& slot_ids = boxps_ptr->GetSlotVector();
used_slot_index->clear();
for (int i = 0; i < use_slot_size_; ++i) {
auto& info = used_slots_info_[i];
Expand All @@ -3129,10 +3130,16 @@ void SlotPaddleBoxDataFeed::GetUsedSlotIndex(
if (!is_slot_values(info.slot)) {
continue;
}
if (slot_name_omited_in_feedpass_.find(info.slot) ==
if (slot_name_omited_in_feedpass_.find(info.slot) !=
slot_name_omited_in_feedpass_.end()) {
used_slot_index->push_back(info.slot_value_idx);
continue;
}
int slot_id = atoi(info.slot.c_str());
if (slot_ids.end() ==
std::find(slot_ids.begin(), slot_ids.end(), slot_id)){
continue;
}
used_slot_index->push_back(info.slot_value_idx);
}
}
bool SlotPaddleBoxDataFeed::Start() {
Expand Down Expand Up @@ -3420,11 +3427,16 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {
int64_t uint64_offset = 0;
offset_timer_.Pause();

auto stream = dynamic_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(this->place_))
->stream();

copy_timer_.Resume();
// copy index
CUDA_CHECK(cudaMemcpy(offsets.data(), d_slot_offsets,
CUDA_CHECK(cudaMemcpyAsync(offsets.data(), d_slot_offsets,
slot_total_num * sizeof(size_t),
cudaMemcpyDeviceToHost));
cudaMemcpyDeviceToHost, stream));
cudaStreamSynchronize(stream);
copy_timer_.Pause();
data_timer_.Resume();

Expand Down Expand Up @@ -3490,9 +3502,9 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {

trans_timer_.Resume();
void** dest_gpu_p = reinterpret_cast<void**>(pack_->slot_buf_ptr());
CUDA_CHECK(cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(),
CUDA_CHECK(cudaMemcpyAsync(dest_gpu_p, h_tensor_ptrs.data(),
use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice));
cudaMemcpyHostToDevice, stream));

CopyForTensor(
ins_num, use_slot_size_, dest_gpu_p, pack_->gpu_slot_offsets(),
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2006,12 +2006,14 @@ void PadBoxSlotDataset::CheckThreadPool(void) {
} else { // shuffle
VLOG(0) << "pass id=" << pass_id_ << ", use shuffle by random id";
}
VLOG(0) << "pass id=" << pass_id_ << ", shuffle disable: " << disable_shuffle_
<< ", polling disable: " << disable_polling_;
used_fea_index_.clear();
auto feed_obj = reinterpret_cast<SlotPaddleBoxDataFeed*>(readers_[0].get());
feed_obj->GetUsedSlotIndex(&used_fea_index_);

VLOG(0) << "pass id=" << pass_id_ << ", shuffle disable: " << disable_shuffle_
<< ", polling disable: " << disable_polling_
<< ", slot num=" << used_fea_index_.size();

// read ins thread
thread_pool_ = GetThreadPool(thread_num_);
// merge thread
Expand Down
Loading

0 comments on commit e0b3ffe

Please sign in to comment.