Skip to content

Commit

Permalink
Merge pull request #10 from seemingwang/develop
Browse files Browse the repository at this point in the history
pull nodes with step
  • Loading branch information
Yelrose authored Mar 22, 2021
2 parents 009cc03 + 09667d1 commit 67aabdb
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 63 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
return fut;
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size,
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<GraphNode> &res) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
Expand Down Expand Up @@ -207,6 +207,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GraphBrpcClient : public BrpcPsClient {
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
int size, int step,
std::vector<GraphNode> &res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index,
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,17 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
if (request.params_size() < 3) {
set_response_code(response, -1,
"pull_graph_list request requires at least 2 arguments");
"pull_graph_list request requires at least 3 arguments");
return 0;
}
int start = *(int *)(request.params(0).c_str());
int size = *(int *)(request.params(1).c_str());
int step = *(int *)(request.params(2).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size, true);
table->pull_graph_list(start, size, buffer, actual_size, step, true);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,13 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
}
std::vector<GraphNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size) {
int start, int size,
int step) {
std::vector<GraphNode> res;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->pull_graph_list(table_id, server_index, start, size, res);
auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
size, step, res);
status.wait();
}
return res;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class GraphPyClient : public GraphPyService {
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
int start, int size);
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();

protected:
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class PSClient {
}
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
int size, int step,
std::vector<GraphNode> &res) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
Expand Down
59 changes: 24 additions & 35 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
namespace paddle {
namespace distributed {

std::vector<GraphNode *> GraphShard::get_batch(int start, int total_size) {
std::vector<GraphNode *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<GraphNode *> res;
for (int pos = start; pos < start + total_size; pos++) {
for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
res.push_back(bucket[pos]);
}
return res;
Expand All @@ -52,15 +52,14 @@ GraphNode *GraphShard::find_node(uint64_t id) {
}

int32_t GraphTable::load(const std::string &path, const std::string &param) {

bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
bool reverse_edge = (param[1] == '<');
return this->load_edges(path, reverse_edge);
}
if (load_node){
std::string node_type = param.substr(1);
if (load_node) {
std::string node_type = param.substr(1);
return this->load_nodes(path, node_type);
}
}
Expand Down Expand Up @@ -125,18 +124,17 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {

std::string nt = values[0];
if (nt != node_type) {
continue;
continue;
}
std::vector<std::string> feature;
for (size_t slice = 2; slice < values.size(); slice++) {
feature.push_back(values[slice]);
}
size_t index = shard_id - shard_start;
if(feature.size() > 0) {
shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
}
else {
shards[index].add_node(id, std::string(""));
if (feature.size() > 0) {
shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
} else {
shards[index].add_node(id, std::string(""));
}
}
}
Expand Down Expand Up @@ -188,7 +186,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type); }
bucket[i]->build_sampler(sample_type);
}
}
return 0;
}
Expand Down Expand Up @@ -315,37 +314,27 @@ int GraphTable::random_sample_neighboors(
}
int32_t GraphTable::pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature) {
int &actual_size, bool need_feature,
int step) {
if (start < 0) start = 0;
int size = 0, cur_size;
if (total_size <= 0) {
actual_size = 0;
return 0;
}
std::vector<std::future<std::vector<GraphNode *>>> tasks;
for (size_t i = 0; i < shards.size(); i++) {
for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
cur_size = shards[i].get_size();
if (size + cur_size <= start) {
size += cur_size;
continue;
}
if (size + cur_size - start >= total_size) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, size, total_size]() -> std::vector<GraphNode *> {
return this->shards[i].get_batch(start - size, total_size);
}));
break;
} else {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, size, total_size,
cur_size]() -> std::vector<GraphNode *> {
return this->shards[i].get_batch(start - size,
size + cur_size - start);
}));
total_size -= size + cur_size - start;
size += cur_size;
start = size;
}
int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<GraphNode *> {

return this->shards[i].get_batch(start - size, end - size, step);
}));
start += count * step;
total_size -= count;
size += cur_size;
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GraphShard {
// bucket.resize(bucket_size);
}
std::vector<GraphNode *> &get_bucket() { return bucket; }
std::vector<GraphNode *> get_batch(int start, int total_size);
std::vector<GraphNode *> get_batch(int start, int end, int step);
// int init_bucket_size(int shard_num) {
// for (int i = bucket_low_bound;; i++) {
// if (gcd(i, shard_num) == 1) return i;
Expand Down Expand Up @@ -78,7 +78,8 @@ class GraphTable : public SparseTable {
virtual ~GraphTable() {}
virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature);
int &actual_size, bool need_feature,
int step);

virtual int32_t random_sample_neighboors(
uint64_t *node_ids, int sample_size,
Expand Down
25 changes: 13 additions & 12 deletions paddle/fluid/distributed/table/graph_edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,34 @@
// limitations under the License.

#pragma once
#include <vector>
#include <cstddef>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {


class GraphEdgeBlob {
public:
public:
GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {}
const size_t size() {return id_arr.size();}
size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight);
const uint64_t get_id(int idx) { return id_arr[idx]; }
virtual const float get_weight(int idx) { return 1; }
protected:
uint64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; }

protected:
std::vector<uint64_t> id_arr;
};

class WeightedGraphEdgeBlob: public GraphEdgeBlob{
public:
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight);
virtual const float get_weight(int idx) { return weight_arr[idx]; }
protected:
virtual float get_weight(int idx) { return weight_arr[idx]; }

protected:
std::vector<float> weight_arr;
};

}
}
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/table/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ class Table {
// only for graph table
virtual int32_t pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature) {
int &actual_size, bool need_feature,
int step = 1) {
return 0;
}
// only for graph table
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ void RunBrpcPushSparse() {
ASSERT_EQ(0, vs[0].size());

std::vector<distributed::GraphNode> nodes;
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes);
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes);
pull_status.wait();
ASSERT_EQ(nodes.size(), 1);
ASSERT_EQ(nodes[0].get_id(), 37);
nodes.clear();
pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, nodes);
pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, 1, nodes);
pull_status.wait();
ASSERT_EQ(nodes.size(), 1);
ASSERT_EQ(nodes[0].get_id(), 59);
Expand Down Expand Up @@ -373,7 +373,7 @@ void RunBrpcPushSparse() {
// client2.load_edge_file(std::string("user2item"), std::string(file_name),
// 0);
nodes.clear();
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4);
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1);

for (auto g : nodes) {
std::cout << "node_ids: " << g.get_id() << std::endl;
Expand Down

0 comments on commit 67aabdb

Please sign in to comment.