Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random settings #125

Open
wants to merge 1 commit into
base: gpugraph
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#include <curand_kernel.h>
#include <sstream>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
Expand Down Expand Up @@ -85,6 +86,18 @@ __global__ void FillSlotValueOffsetKernel(const int ins_num,
}
}

__global__ void shuffle_array(uint64_t* keys,
size_t len) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
curandState rng;
curand_init(clock64(), threadIdx.x, 0, &rng);
if (i < len) {
const size_t j = curand(&rng) % len;
keys[i] = atomicExch(reinterpret_cast<unsigned long long int*>(keys + j),
static_cast<unsigned long long int>(keys[i]));
}
}

__global__ void fill_actual_neighbors(int64_t* vals,
int64_t* actual_vals,
int64_t* actual_vals_dst,
Expand Down Expand Up @@ -1412,6 +1425,18 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place,
// CUDA_CHECK(cudaMemcpyAsync(d_device_keys_->ptr(), h_device_keys_->data(),
// device_key_size_ * sizeof(int64_t),
// cudaMemcpyHostToDevice, stream_));

// Shuffle start_nodes(d_device_keys_) when training.
if (gpu_graph_training_) {
for (size_t i = 0; i < h_device_keys_.size(); i++) {
uint64_t* device_ptr = reinterpret_cast<uint64_t*>(d_device_keys_[i]->ptr());
shuffle_array<<<GET_BLOCKS(h_device_keys_[i]->size()),
CUDA_NUM_THREADS,
0,
stream_>>>(device_ptr, h_device_keys_[i]->size());
}
}

size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
d_prefix_sum_ =
memory::AllocShared(place_, (once_max_sample_keynum + 1) * sizeof(int));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ __global__ void neighbor_sample_kernel(GpuPsCommGraph graph,
int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, n);
curandState rng;
curand_init(blockIdx.x, threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
curand_init(clock64(), threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);
while (i < last_idx) {
if (node_info_list[i].neighbor_size == 0) {
actual_size[i] = default_value;
Expand Down