Skip to content

Commit

Permalink
Update unweighted sample (PaddlePaddle#240)
Browse files Browse the repository at this point in the history
* add unweighted sample from wholegraph

* optimize kernel

* comment out unused code
  • Loading branch information
DesmonDay authored Mar 28, 2023
1 parent cf361ae commit dbbded6
Show file tree
Hide file tree
Showing 2 changed files with 283 additions and 30 deletions.
12 changes: 11 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,19 @@ class GpuPsGraphTable
int shard_len,
bool need_neighbor_count,
unsigned long long random_seed,
int default_value,
float* weight_array,
bool return_weight);
void unweighted_sample(GpuPsCommGraph& graph,
GpuPsNodeInfo* node_info_list,
int* actual_size_array,
uint64_t* sample_array,
int cur_gpu_id,
int remote_gpu_id,
int sample_size,
int shard_len,
unsigned long long random_seed,
float* weight_array,
bool return_weight);
std::vector<std::shared_ptr<phi::Allocation>> get_edge_type_graph(
int gpu_id, int edge_type_len);
void get_node_degree(int gpu_id,
Expand Down
Loading

0 comments on commit dbbded6

Please sign in to comment.