Commit e3b8377: merge develop

Thunderbrook committed May 17, 2022
2 parents ac1eeb7 + 9b15efc commit e3b8377
Showing 2 changed files with 119 additions and 90 deletions.
205 changes: 116 additions & 89 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -32,10 +32,11 @@ sample_result is to save the neighbor sampling result, its size is len *
sample_size;
*/

__global__ void get_cpu_id_index(int64_t* key, int64_t* val, int64_t* cpu_key,
int* sum, int* index, int len) {
__global__ void get_cpu_id_index(int64_t* key, int* actual_sample_size,
int64_t* cpu_key, int* sum, int* index,
int len) {
CUDA_KERNEL_LOOP(i, len) {
if (val[i] == -1) {
if (actual_sample_size[i] == -1) {
int old = atomicAdd(sum, 1);
cpu_key[old] = key[i];
index[old] = i;
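
The reworked kernel keys off actual_sample_size rather than the raw sampled values: entries still equal to -1 belong to keys that missed in the GPU tables (see the default_value handling further down) and must be resampled on the CPU. The sum/index pair is a count-and-compact convention: sum receives how many such keys there are, index their positions. A minimal launch sketch of that convention, with illustrative names rather than the exact variables in this file:

// Illustrative only: t_index[0] receives the number of CPU-bound keys,
// t_index[1..len] their positions in the original key array.
thrust::device_vector<int64_t> t_cpu_keys(len);
thrust::device_vector<int> t_index(len + 1, 0);
get_cpu_id_index<<<grid_size, block_size, 0, stream>>>(
    key, actual_sample_size,
    thrust::raw_pointer_cast(t_cpu_keys.data()),
    thrust::raw_pointer_cast(t_index.data()),      // sum
    thrust::raw_pointer_cast(t_index.data()) + 1,  // index
    len);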
@@ -44,11 +45,35 @@ __global__ void get_cpu_id_index(int64_t* key, int64_t* val, int64_t* cpu_key,
}
}

__global__ void get_actual_gpu_ac(int* gpu_ac, int number_on_cpu) {
CUDA_KERNEL_LOOP(i, number_on_cpu) { gpu_ac[i] /= sizeof(int64_t); }
}

template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void copy_buffer_ac_to_final_place(
int64_t* gpu_buffer, int* gpu_ac, int64_t* val, int* actual_sample_size,
int* index, int* cumsum_gpu_ac, int number_on_cpu, int sample_size) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);

int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx =
min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, number_on_cpu);
while (i < last_idx) {
actual_sample_size[index[i]] = gpu_ac[i];
for (int j = threadIdx.x; j < gpu_ac[i]; j += WARP_SIZE) {
val[index[i] * sample_size + j] = gpu_buffer[cumsum_gpu_ac[i] + j];
}
i += BLOCK_WARPS;
}
}

template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,
int64_t* node_index,
int* actual_size, int64_t* res,
int sample_len, int n) {
int sample_len, int n,
int default_value) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
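
The two helper kernels added above (get_actual_gpu_ac and copy_buffer_ac_to_final_place) post-process the CPU fallback: the first converts the per-key sizes reported by the CPU sampler from bytes into int64_t element counts, the second assigns one warp per CPU-sampled key and scatters its neighbors from the packed CPU buffer into that key's slot of the final result. A sequential C++ sketch of the same copy, for reference only and not part of the commit:

// Host-side equivalent of copy_buffer_ac_to_final_place (illustrative;
// parameter names mirror the kernel's).
void copy_cpu_samples_sequential(const int64_t* gpu_buffer, const int* gpu_ac,
                                 int64_t* val, int* actual_sample_size,
                                 const int* index, const int* cumsum_gpu_ac,
                                 int number_on_cpu, int sample_size) {
  for (int i = 0; i < number_on_cpu; ++i) {
    int count = gpu_ac[i];                 // element count, already divided by sizeof(int64_t)
    actual_sample_size[index[i]] = count;  // record the real sample size for this key
    for (int j = 0; j < count; ++j) {
      // cumsum_gpu_ac[i] is the exclusive prefix sum of the counts, i.e. where
      // this key's neighbors start inside the packed buffer.
      val[index[i] * sample_size + j] = gpu_buffer[cumsum_gpu_ac[i] + j];
    }
  }
}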

@@ -59,7 +84,7 @@ __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,

while (i < last_idx) {
if (node_index[i] == -1) {
actual_size[i] = 0;
actual_size[i] = default_value;
i += BLOCK_WARPS;
continue;
}
@@ -765,6 +790,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
int default_value = 0;
if (cpu_query_switch) {
default_value = -1;
}
cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
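
Note that default_value is what neighbor_sample_example_v2 now writes into actual_size for keys whose node_index is -1, i.e. keys not present in the GPU table. With cpu_query_switch enabled it becomes -1, which is exactly the sentinel that get_cpu_id_index later scans for when collecting keys for the CPU fallback.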
@@ -804,14 +833,8 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
walk_to_dest(gpu_id, total_gpu, h_left, h_right,
(uint64_t*)(d_shard_keys_ptr), NULL);
// For cpu_query_switch, we need global items.
std::vector<thrust::device_vector<int64_t>> cpu_keys_list;
std::vector<thrust::device_vector<int>> cpu_index_list;
thrust::device_vector<int64_t> tmp1;
thrust::device_vector<int> tmp2;
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
// Insert empty object
continue;
}
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
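
The per-GPU cpu_keys_list / cpu_index_list bookkeeping (and the per-shard CPU queries that used it) is dropped here; in the new flow further down, the CPU fallback runs once on the source GPU, after the sampled results have been copied back, using the global key array.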
@@ -840,108 +863,112 @@
WARP_SIZE, BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, resource_->remote_stream(i, gpu_id)>>>(
graph, id_array, actual_size_array, sample_array, sample_size,
shard_len);
// cpu_graph_table->random_sample_neighbors
// if (cpu_query_switch) {
//}
shard_len, default_value);
}
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
if (cpu_query_switch) {
cpu_keys_list.emplace_back(tmp1);
cpu_index_list.emplace_back(tmp2);
}
continue;
}
cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
if (cpu_query_switch) {
platform::CUDADeviceGuard guard(resource_->dev_id(i));
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
auto& node = path_[gpu_id][i].nodes_.back();
int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
int* actual_size_array = (int*)(id_array + shard_len);
int64_t* sample_array =
(int64_t*)(actual_size_array + shard_len + shard_len % 2);
thrust::device_vector<int64_t> cpu_keys_ptr(shard_len);
thrust::device_vector<int> index_ptr(shard_len + 1, 0);
int64_t* node_id_array = reinterpret_cast<int64_t*>(node.key_storage);
int grid_size2 = (shard_len - 1) / block_size_ + 1;
get_cpu_id_index<<<grid_size2, block_size_, 0,
resource_->remote_stream(i, gpu_id)>>>(
node_id_array, id_array,
thrust::raw_pointer_cast(cpu_keys_ptr.data()),
thrust::raw_pointer_cast(index_ptr.data()),
thrust::raw_pointer_cast(index_ptr.data()) + 1, shard_len);
cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
cpu_keys_list.emplace_back(cpu_keys_ptr);
cpu_index_list.emplace_back(index_ptr);
}
}
if (cpu_query_switch) {
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
platform::CUDADeviceGuard guard(resource_->dev_id(i));
auto shard_len = h_right[i] - h_left[i] + 1;
int* cpu_index = new int[shard_len + 1];
cudaMemcpy(cpu_index, thrust::raw_pointer_cast(cpu_index_list[i].data()),
(shard_len + 1) * sizeof(int), cudaMemcpyDeviceToHost);
if (cpu_index[0] > 0) {
int number_on_cpu = cpu_index[0];
int64_t* cpu_keys = new int64_t[number_on_cpu];
cudaMemcpy(cpu_keys, thrust::raw_pointer_cast(cpu_keys_list[i].data()),
number_on_cpu * sizeof(int64_t), cudaMemcpyDeviceToHost);
std::vector<std::shared_ptr<char>> buffers(number_on_cpu);
std::vector<int> ac(number_on_cpu);
auto status = cpu_graph_table->random_sample_neighbors(
0, cpu_keys, sample_size, buffers, ac, false);
auto& node = path_[gpu_id][i].nodes_.back();
// display_sample_res(node.key_storage,node.val_storage,shard_len,sample_size);
int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
int* actual_size_array = (int*)(id_array + shard_len);
int64_t* sample_array =
(int64_t*)(actual_size_array + shard_len + shard_len % 2);
for (int j = 0; j < number_on_cpu; j++) {
int offset = cpu_index[j + 1] * sample_size;
ac[j] = ac[j] / sizeof(int64_t);
/*
std::cerr<<"for cpu key "<<cpu_keys[j]<<" ac_size = "<<ac[j];
int64_t *sss = (int64_t*)(buffers[j].get());
for(int t = 0; t < ac[j]; t++){
std::cerr<<" sampled neighbor ****** "<<sss[t];
}
std::cerr<<"index = "<<cpu_index[j+1]<<std::endl;
*/
cudaMemcpy(sample_array + offset, (int64_t*)(buffers[j].get()),
sizeof(int64_t) * ac[j], cudaMemcpyHostToDevice);
cudaMemcpy(actual_size_array + cpu_index[j + 1], ac.data() + j,
sizeof(int), cudaMemcpyHostToDevice);
// display_sample_res(node.key_storage,node.val_storage,shard_len,sample_size);
}
delete[] cpu_keys;
}
delete[] cpu_index;
}
}
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
h_left, h_right, d_shard_vals_ptr,
d_shard_actual_sample_size_ptr);
fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
d_idx_ptr, sample_size, len);
cudaStreamSynchronize(stream);
if (cpu_query_switch) {
// Get cpu keys and corresponding position.
thrust::device_vector<int64_t> t_cpu_keys(len);
thrust::device_vector<int> t_index(len + 1, 0);
get_cpu_id_index<<<grid_size, block_size_, 0, stream>>>(
key, actual_sample_size, thrust::raw_pointer_cast(t_cpu_keys.data()),
thrust::raw_pointer_cast(t_index.data()),
thrust::raw_pointer_cast(t_index.data()) + 1, len);
cudaStreamSynchronize(stream);
int number_on_cpu = 0;
cudaMemcpy(&number_on_cpu, thrust::raw_pointer_cast(t_index.data()),
sizeof(int), cudaMemcpyDeviceToHost);
if (number_on_cpu > 0) {
int64_t* cpu_keys = new int64_t[number_on_cpu];
cudaMemcpy(cpu_keys, thrust::raw_pointer_cast(t_cpu_keys.data()),
number_on_cpu * sizeof(int64_t), cudaMemcpyDeviceToHost);
std::vector<std::shared_ptr<char>> buffers(number_on_cpu);
std::vector<int> ac(number_on_cpu);
auto status = cpu_graph_table->random_sample_neighbors(
0, cpu_keys, sample_size, buffers, ac, false);
int total_cpu_sample_size = std::accumulate(ac.begin(), ac.end(), 0);
total_cpu_sample_size /= sizeof(int64_t);
// Merge buffers into one int64_t vector.
int64_t* merge_buffers = new int64_t[total_cpu_sample_size];
int start = 0;
for (int j = 0; j < number_on_cpu; j++) {
memcpy(merge_buffers + start, (int64_t*)(buffers[j].get()), ac[j]);
start += ac[j] / sizeof(int64_t);
}
// Copy merge_buffers to gpu.
thrust::device_vector<int64_t> gpu_buffers(total_cpu_sample_size);
thrust::device_vector<int> gpu_ac(number_on_cpu);
int64_t* gpu_buffers_ptr = thrust::raw_pointer_cast(gpu_buffers.data());
int* gpu_ac_ptr = thrust::raw_pointer_cast(gpu_ac.data());
cudaMemcpyAsync(gpu_buffers_ptr, merge_buffers,
total_cpu_sample_size * sizeof(int64_t),
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(gpu_ac_ptr, ac.data(), number_on_cpu * sizeof(int),
cudaMemcpyHostToDevice, stream);
// Copy gpu_buffers and gpu_ac using kernel.
// Kernel divide for gpu_ac_ptr.
int grid_size2 = (number_on_cpu - 1) / block_size_ + 1;
get_actual_gpu_ac<<<grid_size2, block_size_, 0, stream>>>(gpu_ac_ptr,
number_on_cpu);
cudaStreamSynchronize(stream);
thrust::device_vector<int> cumsum_gpu_ac(number_on_cpu);
thrust::exclusive_scan(gpu_ac.begin(), gpu_ac.end(),
cumsum_gpu_ac.begin(), 0);
constexpr int WARP_SIZE_ = 32;
constexpr int BLOCK_WARPS_ = 128 / WARP_SIZE_;
constexpr int TILE_SIZE_ = BLOCK_WARPS_ * 16;
const dim3 block2(WARP_SIZE_, BLOCK_WARPS_);
const dim3 grid2((number_on_cpu + TILE_SIZE_ - 1) / TILE_SIZE_);
copy_buffer_ac_to_final_place<WARP_SIZE_, BLOCK_WARPS_,
TILE_SIZE_><<<grid2, block2, 0, stream>>>(
gpu_buffers_ptr, gpu_ac_ptr, val, actual_sample_size,
thrust::raw_pointer_cast(t_index.data()) + 1,
thrust::raw_pointer_cast(cumsum_gpu_ac.data()), number_on_cpu,
sample_size);
delete[] merge_buffers;
delete[] cpu_keys;
}
}
{
cudaStreamSynchronize(stream);
platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
thrust::device_vector<int> t_actual_sample_size(len);
thrust::copy(actual_sample_size, actual_sample_size + len,
t_actual_sample_size.begin());
int total_sample_size = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(int64_t));
result.actual_val = (int64_t*)(result.actual_val_mem)->ptr();
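
One bookkeeping detail in the new CPU-fallback path is easy to miss: judging by the divisions by sizeof(int64_t) throughout this diff, cpu_graph_table->random_sample_neighbors reports per-key sizes in bytes, so the code converts them to neighbor counts (get_actual_gpu_ac) and takes an exclusive prefix sum over those counts to find each key's offset in the merged buffer. A small self-contained sketch of that arithmetic, with made-up sizes:

#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  // ac[j]: bytes returned by the CPU sampler for key j (example values).
  std::vector<int> ac = {24, 8, 16};
  // Bytes -> element counts, as get_actual_gpu_ac does on the GPU.
  for (int& a : ac) a /= static_cast<int>(sizeof(std::int64_t));  // {3, 1, 2}
  // Exclusive prefix sum: start offset of each key's neighbors in the merged
  // buffer, as thrust::exclusive_scan computes in the new code.
  std::vector<int> offsets(ac.size());
  std::exclusive_scan(ac.begin(), ac.end(), offsets.begin(), 0);
  // offsets == {0, 3, 4}; the merged buffer holds 3 + 1 + 2 = 6 neighbors.
  return 0;
}
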
4 changes: 3 additions & 1 deletion paddle/fluid/platform/device/ipu/ipu_executor.cc
@@ -197,7 +197,9 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
}
VLOG(10) << "Prepared inputs/anchors";

if (ipu_strategy_->is_training && compiler_resources_->with_lr_sched) {
if (ipu_strategy_->is_training && compiler_resources_->with_lr_sched &&
!(ipu_strategy_->popart_options.createImplicitPipeliningFwdOnlyProgram &&
ipu_strategy_->runtime_options.enable_eval)) {
popart::Optimizer *optimizer;
if (ipu_strategy_->runtime_options.enable_eval) {
VLOG(10) << "Switch optimizer to eval mode";

1 comment on commit e3b8377

@paddle-bot-old
Congratulations! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
