From d6c50e80f8615a2d11f523cd1ebb40ee5e074abb Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Thu, 22 Sep 2022 19:17:24 +0800
Subject: [PATCH] add random settings

---
 paddle/fluid/framework/data_feed.cu              | 25 +++++++++++++++++++
 .../fleet/heter_ps/graph_gpu_ps_table_inl.cu     |  2 +-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 148446b8075f8..ef247ede91a08 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -22,6 +22,7 @@ limitations under the License. */
 #include <thrust/device_ptr.h>
 #include <thrust/functional.h>
 #include <thrust/shuffle.h>
+#include <curand_kernel.h>
 #include <sstream>
 #include "cub/cub.cuh"
 #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
@@ -85,6 +86,18 @@ __global__ void FillSlotValueOffsetKernel(const int ins_num,
   }
 }
 
+__global__ void shuffle_array(uint64_t* keys,
+                              size_t len) {
+  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  curandState rng;
+  curand_init(clock64(), threadIdx.x, 0, &rng);
+  if (i < len) {
+    const size_t j = curand(&rng) % len;
+    keys[i] = atomicExch(reinterpret_cast<unsigned long long*>(keys + j),
+                         static_cast<unsigned long long>(keys[i]));
+  }
+}
+
 __global__ void fill_actual_neighbors(int64_t* vals,
                                       int64_t* actual_vals,
                                       int64_t* actual_vals_dst,
@@ -1412,6 +1425,18 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place,
   // CUDA_CHECK(cudaMemcpyAsync(d_device_keys_->ptr(), h_device_keys_->data(),
   //                            device_key_size_ * sizeof(int64_t),
   //                            cudaMemcpyHostToDevice, stream_));
+
+  // Shuffle start_nodes(d_device_keys_) when training.
+  if (gpu_graph_training_) {
+    for (size_t i = 0; i < h_device_keys_.size(); i++) {
+      uint64_t* device_ptr = reinterpret_cast<uint64_t*>(d_device_keys_[i]->ptr());
+      shuffle_array<<<GET_BLOCKS(h_device_keys_[i]->size()),
+                      CUDA_NUM_THREADS,
+                      0,
+                      stream_>>>(device_ptr, h_device_keys_[i]->size());
+    }
+  }
+
   size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
   d_prefix_sum_ =
       memory::AllocShared(place_, (once_max_sample_keynum + 1) * sizeof(int));
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
index 7f92037a1acb4..a2f5dcae7d051 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -129,7 +129,7 @@ __global__ void neighbor_sample_kernel(GpuPsCommGraph graph,
   int i = blockIdx.x * TILE_SIZE + threadIdx.y;
   const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, n);
   curandState rng;
-  curand_init(blockIdx.x, threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
+  curand_init(clock64(), threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
   while (i < last_idx) {
     if (node_info_list[i].neighbor_size == 0) {
       actual_size[i] = default_value;