diff --git a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc index e18848af0dc08..b4321a85ab2ee 100644 --- a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc @@ -39,17 +39,42 @@ void SampleUniqueNeighbors( } } +template +void SampleUniqueNeighborsWithEids( + bidiiter src_begin, + bidiiter src_end, + bidiiter eid_begin, + bidiiter eid_end, + int num_samples, + std::mt19937& rng, + std::uniform_int_distribution& dice_distribution) { + int left_num = std::distance(src_begin, src_end); + for (int i = 0; i < num_samples; i++) { + bidiiter r1 = src_begin, r2 = eid_begin; + int random_step = dice_distribution(rng) % left_num; + std::advance(r1, random_step); + std::advance(r2, random_step); + std::swap(*src_begin, *r1); + std::swap(*eid_begin, *r2); + ++src_begin; + ++eid_begin; + --left_num; + } +} + template void SampleNeighbors(const T* row, const T* col_ptr, + const T* eids, const T* input, std::vector* output, std::vector* output_count, + std::vector* output_eids, int sample_size, - int bs) { - // Allocate the memory of output - // Collect the neighbors size + int bs, + bool return_eids) { std::vector> out_src_vec; + std::vector> out_eids_vec; // `sample_cumsum_sizes` record the start position and end position // after sampling. std::vector sample_cumsum_sizes(bs + 1); @@ -65,10 +90,18 @@ void SampleNeighbors(const T* row, std::vector out_src; out_src.resize(cap); out_src_vec.emplace_back(out_src); + if (return_eids) { + std::vector out_eids; + out_eids.resize(cap); + out_eids_vec.emplace_back(out_eids); + } } output_count->resize(bs); output->resize(total_neighbors); + if (return_eids) { + output_eids->resize(total_neighbors); + } std::random_device rd; std::mt19937 rng{rd()}; @@ -85,15 +118,28 @@ void SampleNeighbors(const T* row, int cap = end - begin; if (sample_size < cap) { std::copy(row + begin, row + end, out_src_vec[i].begin()); - // TODO(daisiming): Check whether is correct. - SampleUniqueNeighbors(out_src_vec[i].begin(), - out_src_vec[i].end(), - sample_size, - rng, - dice_distribution); + if (return_eids) { + std::copy(eids + begin, eids + end, out_eids_vec[i].begin()); + SampleUniqueNeighborsWithEids(out_src_vec[i].begin(), + out_src_vec[i].end(), + out_eids_vec[i].begin(), + out_eids_vec[i].end(), + sample_size, + rng, + dice_distribution); + } else { + SampleUniqueNeighbors(out_src_vec[i].begin(), + out_src_vec[i].end(), + sample_size, + rng, + dice_distribution); + } *(output_count->data() + i) = sample_size; } else { std::copy(row + begin, row + end, out_src_vec[i].begin()); + if (return_eids) { + std::copy(eids + begin, eids + end, out_eids_vec[i].begin()); + } *(output_count->data() + i) = cap; } } @@ -107,6 +153,11 @@ void SampleNeighbors(const T* row, std::copy(out_src_vec[i].begin(), out_src_vec[i].begin() + k, output->data() + sample_cumsum_sizes[i]); + if (return_eids) { + std::copy(out_eids_vec[i].begin(), + out_eids_vec[i].begin() + k, + output_eids->data() + sample_cumsum_sizes[i]); + } } } @@ -131,8 +182,35 @@ void GraphSampleNeighborsKernel( std::vector output; std::vector output_count; - SampleNeighbors( - row_data, col_ptr_data, x_data, &output, &output_count, sample_size, bs); + + if (return_eids) { + const T* eids_data = eids.get_ptr()->data(); + std::vector output_eids; + SampleNeighbors(row_data, + col_ptr_data, + eids_data, + x_data, + &output, + &output_count, + &output_eids, + sample_size, + bs, + return_eids); + out_eids->Resize({static_cast(output_eids.size())}); + T* out_eids_data = dev_ctx.template Alloc(out_eids); + std::copy(output_eids.begin(), output_eids.end(), out_eids_data); + } else { + SampleNeighbors(row_data, + col_ptr_data, + nullptr, + x_data, + &output, + &output_count, + nullptr, + sample_size, + bs, + return_eids); + } out->Resize({static_cast(output.size())}); T* out_data = dev_ctx.template Alloc(out); std::copy(output.begin(), output.end(), out_data); diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu index 1757b6b98dbf9..af616963b499a 100644 --- a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -62,9 +62,11 @@ __global__ void SampleKernel(const uint64_t rand_seed, const T* nodes, const T* row, const T* col_ptr, + const T* eids, T* output, + T* output_eids, int* output_ptr, - int* output_idxs) { + bool return_eids) { assert(blockDim.x == WARP_SIZE); assert(blockDim.y == BLOCK_WARPS); @@ -94,10 +96,13 @@ __global__ void SampleKernel(const uint64_t rand_seed, if (deg <= k) { for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { output[out_row_start + idx] = row[in_row_start + idx]; + if (return_eids) { + output_eids[out_row_start + idx] = eids[in_row_start + idx]; + } } } else { for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { - output_idxs[out_row_start + idx] = idx; + output[out_row_start + idx] = idx; } #ifdef PADDLE_WITH_CUDA __syncwarp(); @@ -111,7 +116,7 @@ __global__ void SampleKernel(const uint64_t rand_seed, #endif if (num < k) { atomicMax(reinterpret_cast( // NOLINT - output_idxs + out_row_start + num), + output + out_row_start + num), static_cast(idx)); // NOLINT } } @@ -120,8 +125,11 @@ __global__ void SampleKernel(const uint64_t rand_seed, #endif for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { - T perm_idx = output_idxs[out_row_start + idx] + in_row_start; + T perm_idx = output[out_row_start + idx] + in_row_start; output[out_row_start + idx] = row[perm_idx]; + if (return_eids) { + output_eids[out_row_start + idx] = eids[perm_idx]; + } } } @@ -148,16 +156,17 @@ template void SampleNeighbors(const Context& dev_ctx, const T* row, const T* col_ptr, + const T* eids, const thrust::device_ptr input, thrust::device_ptr output, thrust::device_ptr output_count, + thrust::device_ptr output_eids, int sample_size, int bs, - int total_sample_num) { + int total_sample_num, + bool return_eids) { thrust::device_vector output_ptr; - thrust::device_vector output_idxs; output_ptr.resize(bs); - output_idxs.resize(total_sample_num); thrust::exclusive_scan( output_count, output_count + bs, output_ptr.begin(), 0); @@ -176,18 +185,26 @@ void SampleNeighbors(const Context& dev_ctx, thrust::raw_pointer_cast(input), row, col_ptr, + eids, thrust::raw_pointer_cast(output), + thrust::raw_pointer_cast(output_eids), thrust::raw_pointer_cast(output_ptr.data()), - thrust::raw_pointer_cast(output_idxs.data())); + return_eids); } -template +template __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, int k, const int64_t num_rows, const T* in_rows, T* src, const T* dst_count) { + assert(blockDim.x == WARP_SIZE); + assert(blockDim.y == BLOCK_WARPS); + + int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; + const int64_t last_row = + min(static_cast(blockIdx.x + 1) * TILE_SIZE, num_rows); #ifdef PADDLE_WITH_HIP hiprandState rng; hiprand_init( @@ -197,20 +214,19 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, curand_init( rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); #endif - CUDA_KERNEL_LOOP(out_row, num_rows) { + + while (out_row < last_row) { const T row = in_rows[out_row]; const T in_row_start = dst_count[row]; const int deg = dst_count[row + 1] - in_row_start; int split; - T tmp; - if (k < deg) { if (deg < 2 * k) { split = k; } else { split = deg - k; } - for (int idx = deg - 1; idx >= split; idx--) { + for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) { #ifdef PADDLE_WITH_HIP const int num = hiprand(&rng) % (idx + 1); #else @@ -222,7 +238,11 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, static_cast( // NOLINT src[in_row_start + idx]))); } +#ifdef PADDLE_WITH_CUDA + __syncwarp(); +#endif } + out_row += BLOCK_WARPS; } } @@ -232,9 +252,12 @@ __global__ void GatherEdge(int k, const T* in_rows, const T* src, const T* dst_count, + const T* eids, T* outputs, + T* output_eids, int* output_ptr, - T* perm_data) { + T* perm_data, + bool return_eids) { assert(blockDim.x == WARP_SIZE); assert(blockDim.y == BLOCK_WARPS); @@ -250,8 +273,10 @@ __global__ void GatherEdge(int k, if (deg <= k) { for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { - const T in_idx = in_row_start + idx; - outputs[out_row_start + idx] = src[in_idx]; + outputs[out_row_start + idx] = src[in_row_start + idx]; + if (return_eids) { + output_eids[out_row_start + idx] = eids[in_row_start + idx]; + } } } else { int split = k; @@ -267,6 +292,10 @@ __global__ void GatherEdge(int k, for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) { outputs[out_row_start + idx - begin] = src[perm_data[in_row_start + idx]]; + if (return_eids) { + output_eids[out_row_start + idx - begin] = + eids[perm_data[in_row_start + idx]]; + } } } out_row += BLOCK_WARPS; @@ -277,49 +306,48 @@ template void FisherYatesSampleNeighbors(const Context& dev_ctx, const T* row, const T* col_ptr, + const T* eids, T* perm_data, const thrust::device_ptr input, thrust::device_ptr output, thrust::device_ptr output_count, + thrust::device_ptr output_eids, int sample_size, int bs, - int total_sample_num) { + int total_sample_num, + bool return_eids) { thrust::device_vector output_ptr; output_ptr.resize(bs); thrust::exclusive_scan( output_count, output_count + bs, output_ptr.begin(), 0); -#ifdef PADDLE_WITH_HIP - int block = 256; -#else - int block = 1024; -#endif - int max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; - int grid_tmp = (bs + block - 1) / block; - int grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + constexpr int WARP_SIZE = 32; + constexpr int BLOCK_WARPS = 128 / WARP_SIZE; + constexpr int TILE_SIZE = BLOCK_WARPS * 16; + const dim3 block(WARP_SIZE, BLOCK_WARPS); + const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); - FisherYatesSampleKernel<<>>( + FisherYatesSampleKernel<<>>( 0, sample_size, bs, thrust::raw_pointer_cast(input), perm_data, col_ptr); - constexpr int GATHER_WARP_SIZE = 32; - constexpr int GATHER_BLOCK_WARPS = 128 / GATHER_WARP_SIZE; - constexpr int GATHER_TILE_SIZE = GATHER_BLOCK_WARPS * 16; - const dim3 gather_block(GATHER_WARP_SIZE, GATHER_BLOCK_WARPS); - const dim3 gather_grid((bs + GATHER_TILE_SIZE - 1) / GATHER_TILE_SIZE); - - GatherEdge< - T, - GATHER_WARP_SIZE, - GATHER_BLOCK_WARPS, - GATHER_TILE_SIZE><<>>( + GatherEdge<<>>( sample_size, bs, thrust::raw_pointer_cast(input), row, col_ptr, + eids, thrust::raw_pointer_cast(output), + thrust::raw_pointer_cast(output_eids), thrust::raw_pointer_cast(output_ptr.data()), - perm_data); + perm_data, + return_eids); } template @@ -354,32 +382,78 @@ void GraphSampleNeighborsKernel( T* out_data = dev_ctx.template Alloc(out); thrust::device_ptr output(out_data); - if (!flag_perm_buffer) { - SampleNeighbors(dev_ctx, - row_data, - col_ptr_data, - input, - output, - output_count, - sample_size, - bs, - total_sample_size); + if (return_eids) { + auto* eids_data = eids.get_ptr()->data(); + out_eids->Resize({static_cast(total_sample_size)}); + T* out_eids_data = dev_ctx.template Alloc(out_eids); + thrust::device_ptr output_eids(out_eids_data); + if (!flag_perm_buffer) { + SampleNeighbors(dev_ctx, + row_data, + col_ptr_data, + eids_data, + input, + output, + output_count, + output_eids, + sample_size, + bs, + total_sample_size, + return_eids); + } else { + DenseTensor perm_buffer_out(perm_buffer->type()); + const auto* p_perm_buffer = perm_buffer.get_ptr(); + perm_buffer_out.ShareDataWith(*p_perm_buffer); + T* perm_buffer_out_data = perm_buffer_out.template data(); + FisherYatesSampleNeighbors(dev_ctx, + row_data, + col_ptr_data, + eids_data, + perm_buffer_out_data, + input, + output, + output_count, + output_eids, + sample_size, + bs, + total_sample_size, + return_eids); + } } else { - DenseTensor perm_buffer_out(perm_buffer->type()); - const auto* p_perm_buffer = perm_buffer.get_ptr(); - perm_buffer_out.ShareDataWith(*p_perm_buffer); - T* perm_buffer_out_data = - perm_buffer_out.mutable_data(dev_ctx.GetPlace()); - FisherYatesSampleNeighbors(dev_ctx, - row_data, - col_ptr_data, - perm_buffer_out_data, - input, - output, - output_count, - sample_size, - bs, - total_sample_size); + // How to set null value for output_eids(thrust::device_ptr)? + // We use `output` to fill the position of unused output_eids. + if (!flag_perm_buffer) { + SampleNeighbors(dev_ctx, + row_data, + col_ptr_data, + nullptr, + input, + output, + output_count, + output, + sample_size, + bs, + total_sample_size, + return_eids); + } else { + DenseTensor perm_buffer_out(perm_buffer->type()); + const auto* p_perm_buffer = perm_buffer.get_ptr(); + perm_buffer_out.ShareDataWith(*p_perm_buffer); + T* perm_buffer_out_data = perm_buffer_out.template data(); + FisherYatesSampleNeighbors(dev_ctx, + row_data, + col_ptr_data, + nullptr, + perm_buffer_out_data, + input, + output, + output_count, + output, + sample_size, + bs, + total_sample_size, + return_eids); + } } } diff --git a/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py b/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py index d2fbeab3fd42c..675a3429ab55f 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py +++ b/python/paddle/fluid/tests/unittests/test_graph_sample_neighbors.py @@ -162,14 +162,14 @@ def check_perm_buffer_error(): self.assertRaises(ValueError, check_perm_buffer_error) def test_sample_result_with_eids(self): - # Note: Currently return eid results is not initialized. paddle.disable_static() row = paddle.to_tensor(self.row) colptr = paddle.to_tensor(self.colptr) nodes = paddle.to_tensor(self.nodes) eids = paddle.to_tensor(self.edges_id) + perm_buffer = paddle.to_tensor(self.edges_id) - out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors( + out_neighbors, out_count, out_eids = paddle.incubate.graph_sample_neighbors( row, colptr, nodes, @@ -177,6 +177,16 @@ def test_sample_result_with_eids(self): sample_size=self.sample_size, return_eids=True) + out_neighbors, out_count, out_eids = paddle.incubate.graph_sample_neighbors( + row, + colptr, + nodes, + eids=eids, + perm_buffer=perm_buffer, + sample_size=self.sample_size, + return_eids=True, + flag_perm_buffer=True) + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): row = paddle.static.data( @@ -188,7 +198,7 @@ def test_sample_result_with_eids(self): eids = paddle.static.data( name="eids", shape=self.edges_id.shape, dtype=self.nodes.dtype) - out_neighbors, out_count, _ = paddle.incubate.graph_sample_neighbors( + out_neighbors, out_count, out_eids = paddle.incubate.graph_sample_neighbors( row, colptr, nodes, @@ -202,7 +212,7 @@ def test_sample_result_with_eids(self): 'nodes': self.nodes, 'eids': self.edges_id }, - fetch_list=[out_neighbors, out_count]) + fetch_list=[out_neighbors, out_count, out_eids]) if __name__ == "__main__":