Skip to content

Commit

Permalink
Merge pull request #12 from qingshui/paddlebox
Browse files Browse the repository at this point in the history
add enable_dense_nccl_barrier flags and add cuda_check remove threadpool catch exception
  • Loading branch information
qingshui authored Jul 7, 2021
2 parents f720b3a + 0e6042c commit 026d37d
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 28 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/boxps_trainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ void BoxPSTrainer::Run() {

void BoxPSTrainer::Finalize() {
for (auto& th : wait_futures_) {
th.wait();
th.get();
}
if (async_mode_) {
// must be after train thread, otherwise the ps_buffer_ will be closed first
Expand Down
21 changes: 13 additions & 8 deletions paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2438,8 +2438,9 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {

copy_timer_.Resume();
// copy index
cudaMemcpy(offsets.data(), d_slot_offsets, slot_total_num * sizeof(size_t),
cudaMemcpyDeviceToHost);
CUDA_CHECK(cudaMemcpy(offsets.data(), d_slot_offsets,
slot_total_num * sizeof(size_t),
cudaMemcpyDeviceToHost));
copy_timer_.Pause();

data_timer_.Resume();
Expand Down Expand Up @@ -2503,8 +2504,9 @@ void SlotPaddleBoxDataFeed::BuildSlotBatchGPU(const int ins_num) {

trans_timer_.Resume();
void** dest_gpu_p = reinterpret_cast<void**>(pack_->slot_buf_ptr());
cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(), use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice);
CUDA_CHECK(cudaMemcpy(dest_gpu_p, h_tensor_ptrs.data(),
use_slot_size_ * sizeof(void*),
cudaMemcpyHostToDevice));

CopyForTensor(ins_num, use_slot_size_, dest_gpu_p,
(const size_t*)pack_->gpu_slot_offsets(),
Expand Down Expand Up @@ -3610,7 +3612,10 @@ MiniBatchGpuPack::MiniBatchGpuPack(const paddle::platform::Place& place,
slot_buf_ptr_ = memory::AllocShared(place_, used_slot_size_ * sizeof(void*));

int device_id = boost::get<platform::CUDAPlace>(place).GetDeviceId();
VLOG(3) << "begin get batch pack device id: " << device_id;
qvalue_tensor_ = &BoxWrapper::GetInstance()->GetQTensor(device_id);
// sync
CUDA_CHECK(cudaStreamSynchronize(stream_));
}

MiniBatchGpuPack::~MiniBatchGpuPack() {}
Expand Down Expand Up @@ -3853,7 +3858,7 @@ void MiniBatchGpuPack::transfer_to_gpu(void) {
copy_host2device(&value_.d_float_lens, buf_.h_float_lens);
copy_host2device(&value_.d_float_keys, buf_.h_float_keys);
copy_host2device(&value_.d_float_offset, buf_.h_float_offset);
cudaStreamSynchronize(stream_);
CUDA_CHECK(cudaStreamSynchronize(stream_));
trans_timer_.Pause();
}

Expand All @@ -3878,9 +3883,9 @@ void MiniBatchGpuPack::pack_qvalue(void) {

float* tensor_ptr =
qvalue_tensor_->mutable_data<float>({len, 1}, this->place_);
cudaMemcpyAsync(tensor_ptr, &qvalue[0], len * sizeof(float),
cudaMemcpyHostToDevice, stream_);
cudaStreamSynchronize(stream_);
CUDA_CHECK(cudaMemcpyAsync(tensor_ptr, &qvalue[0], len * sizeof(float),
cudaMemcpyHostToDevice, stream_));
CUDA_CHECK(cudaStreamSynchronize(stream_));
}

// store pcoc q value
Expand Down
17 changes: 10 additions & 7 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,7 @@ struct UsedSlotGpuType {
int is_uint64_value;
int slot_value_idx;
};
#define CUDA_CHECK(val) CHECK(val == cudaSuccess)
template <typename T>
struct CudaBuffer {
T* cu_buffer;
Expand All @@ -1301,11 +1302,12 @@ struct CudaBuffer {
uint64_t size() { return buf_size; }
void malloc(uint64_t size) {
buf_size = size;
cudaMalloc(reinterpret_cast<void**>(&cu_buffer), size * sizeof(T));
CUDA_CHECK(
cudaMalloc(reinterpret_cast<void**>(&cu_buffer), size * sizeof(T)));
}
void free() {
if (cu_buffer != NULL) {
cudaFree(cu_buffer);
CUDA_CHECK(cudaFree(cu_buffer));
cu_buffer = NULL;
}
buf_size = 0;
Expand Down Expand Up @@ -1341,12 +1343,13 @@ struct HostBuffer {
const T& operator[](size_t i) const { return host_buffer[i]; }
void malloc(size_t len) {
buf_size = len;
cudaHostAlloc(reinterpret_cast<void**>(&host_buffer), buf_size * sizeof(T),
cudaHostAllocDefault);
CUDA_CHECK(cudaHostAlloc(reinterpret_cast<void**>(&host_buffer),
buf_size * sizeof(T), cudaHostAllocDefault));
CHECK(host_buffer != NULL);
}
void free() {
if (host_buffer != NULL) {
cudaFreeHost(host_buffer);
CUDA_CHECK(cudaFreeHost(host_buffer));
host_buffer = NULL;
}
buf_size = 0;
Expand Down Expand Up @@ -1471,8 +1474,8 @@ class MiniBatchGpuPack {
return;
}
buf->resize(size);
cudaMemcpyAsync(buf->data(), val, size * sizeof(T), cudaMemcpyHostToDevice,
stream_);
CUDA_CHECK(cudaMemcpyAsync(buf->data(), val, size * sizeof(T),
cudaMemcpyHostToDevice, stream_));
}
template <typename T>
void copy_host2device(CudaBuffer<T>* buf, const HostBuffer<T>& val) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1571,7 +1571,7 @@ void PadBoxSlotDataset::PreLoadIntoMemory() {
}
void PadBoxSlotDataset::WaitPreLoadDone() {
for (auto& f : wait_futures_) {
f.wait();
f.get();
}
if (data_consumer_ != nullptr) {
delete reinterpret_cast<PadBoxSlotDataConsumer*>(data_consumer_);
Expand Down Expand Up @@ -1619,7 +1619,7 @@ void PadBoxSlotDataset::LoadIntoMemory() {
}
// wait all thread finish
for (auto& f : wait_futures_) {
f.wait();
f.get();
}

if (data_consumer_ != nullptr) {
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/framework/fleet/box_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ limitations under the License. */

DECLARE_int32(fix_dayid);
DECLARE_bool(padbox_auc_runner_mode);
#ifdef PADDLE_WITH_BOX_PS
DECLARE_bool(enable_sparse_push_barrier);
#endif
DECLARE_bool(enable_dense_nccl_barrier);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -615,7 +613,7 @@ class BoxWrapper {
}
if (flag & 0x02) {
if (pause) {
if (FLAGS_enable_sparse_push_barrier) {
if (FLAGS_enable_dense_nccl_barrier) {
boxps::MPICluster::Ins().barrier();
}
dev.dense_sync_timer.Pause();
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ class ThreadPool {
try {
fn();
} catch (platform::EnforceNotMet& ex) {
CHECK(false) << "Unexpected exception is catched in thread pool: "
<< ex.what();
return std::unique_ptr<platform::EnforceNotMet>(
new platform::EnforceNotMet(ex));
} catch (const std::exception& e) {
CHECK(false) << "Unexpected exception is catched in thread pool: "
<< e.what();
PADDLE_THROW(platform::errors::Fatal(
"Unexpected exception is catched in thread pool. All "
"throwable exception in Paddle should be an EnforceNotMet."
Expand Down
14 changes: 9 additions & 5 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,11 @@ DEFINE_double(local_exe_sub_scope_limit, 256.0, // MBytes
DEFINE_int32(fix_dayid, 0, "Whether fix dayid in PaddleBox");
DEFINE_int32(padbox_record_pool_max_size, 2000000,
"PadBoxSlotDataset slot record pool max size");
DEFINE_int32(padbox_dataset_shuffle_thread_num, 10,
DEFINE_int32(padbox_dataset_shuffle_thread_num, 20,
"PadBoxSlotDataset shuffle thread num");
DEFINE_int32(padbox_dataset_merge_thread_num, 10,
DEFINE_int32(padbox_dataset_merge_thread_num, 20,
"PadBoxSlotDataset shuffle thread num");
DEFINE_int32(padbox_slotpool_thread_num, 1,
DEFINE_int32(padbox_slotpool_thread_num, 20,
"PadBoxSlotDataset slot pool thread num");
DEFINE_bool(use_gpu_replica_cache, false,
"if true ,will open use_gpu_replica_cache");
Expand All @@ -495,9 +495,11 @@ DEFINE_bool(padbox_dataset_disable_polling, false,
DEFINE_bool(padbox_dataset_enable_unrollinstance, false,
"if true ,will enable unrollinstance");
DEFINE_bool(lineid_have_extend_info, false,
"if true , will split line id by space into 2 part, the second part will dump at the last of line");
"if true , will split line id by space into 2 part, the second "
"part will dump at the last of line");
DEFINE_bool(dump_filed_same_as_aibox, false,
"if true , will change dump format from abc.tmp0:2:1:1 into abc:1:1, which same as aibox");
"if true , will change dump format from abc.tmp0:2:1:1 into "
"abc:1:1, which same as aibox");

/**
* MKLDNN related FLAG
Expand Down Expand Up @@ -592,3 +594,5 @@ DEFINE_bool(enable_sync_dense_moment, false,
"enable sync dense moment, default false");
DEFINE_bool(enable_ins_parser_file, false,
"enable parser ins file , default false");
DEFINE_bool(enable_dense_nccl_barrier, false,
"enable dense nccl barrier , default false");
2 changes: 1 addition & 1 deletion paddle/fluid/platform/gpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count,

void GpuMemcpySync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpy(dst, src, count, kind));
CHECK(cudaMemcpy(dst, src, count, kind) == cudaSuccess);
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def __bootstrap__():
'padbox_dataset_enable_unrollinstance',
'enable_binding_train_cpu',
'enable_ins_parser_file',
'enable_dense_nccl_barrier',
]
core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)])
core.init_glog(sys.argv[0])
Expand Down

0 comments on commit 026d37d

Please sign in to comment.