diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc index 2c9ee2b9037b7..1b54ce986d9c2 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -173,7 +173,7 @@ std::future GraphBrpcClient::random_sample_nodes( return fut; } std::future GraphBrpcClient::pull_graph_list( - uint32_t table_id, int server_index, int start, int size, + uint32_t table_id, int server_index, int start, int size, int step, std::vector &res) { DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) { int ret = 0; @@ -207,6 +207,7 @@ std::future GraphBrpcClient::pull_graph_list( closure->request(0)->set_client_id(_client_id); closure->request(0)->add_params((char *)&start, sizeof(int)); closure->request(0)->add_params((char *)&size, sizeof(int)); + closure->request(0)->add_params((char *)&step, sizeof(int)); PsService_Stub rpc_stub(get_cmd_channel(server_index)); closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h index 8f14e646050c3..e762409fbeeda 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ -41,7 +41,7 @@ class GraphBrpcClient : public BrpcPsClient { std::vector>> &res); virtual std::future pull_graph_list(uint32_t table_id, int server_index, int start, - int size, + int size, int step, std::vector &res); virtual std::future random_sample_nodes(uint32_t table_id, int server_index, diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index cd1c2330a7b85..60d6bc203a074 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -260,16 +260,17 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, PsResponseMessage &response, brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) - if (request.params_size() < 2) { + if (request.params_size() < 3) { set_response_code(response, -1, - "pull_graph_list request requires at least 2 arguments"); + "pull_graph_list request requires at least 3 arguments"); return 0; } int start = *(int *)(request.params(0).c_str()); int size = *(int *)(request.params(1).c_str()); + int step = *(int *)(request.params(2).c_str()); std::unique_ptr buffer; int actual_size; - table->pull_graph_list(start, size, buffer, actual_size, true); + table->pull_graph_list(start, size, buffer, actual_size, step, true); cntl->response_attachment().append(buffer.get(), actual_size); return 0; } diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index d37eb289f66ea..752dd0bcba2f0 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -219,12 +219,13 @@ std::vector GraphPyClient::random_sample_nodes(std::string name, } std::vector GraphPyClient::pull_graph_list(std::string name, int server_index, - int start, int size) { + int start, int size, + int step) { std::vector res; if (this->table_id_map.count(name)) { uint32_t table_id = this->table_id_map[name]; - auto status = - worker_ptr->pull_graph_list(table_id, server_index, start, size, res); + auto status = worker_ptr->pull_graph_list(table_id, server_index, start, + size, step, res); status.wait(); } return res; diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index 9283618dc9533..7452c99c07f60 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -127,7 +127,7 @@ class GraphPyClient : public GraphPyService { std::vector random_sample_nodes(std::string name, int server_index, int sample_size); std::vector pull_graph_list(std::string name, int server_index, - int start, int size); + int start, int size, int step = 1); ::paddle::distributed::PSParameter GetWorkerProto(); protected: diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h index d3274b4a57334..1b61e35641322 100644 --- a/paddle/fluid/distributed/service/ps_client.h +++ b/paddle/fluid/distributed/service/ps_client.h @@ -166,7 +166,7 @@ class PSClient { } virtual std::future pull_graph_list(uint32_t table_id, int server_index, int start, - int size, + int size, int step, std::vector &res) { LOG(FATAL) << "Did not implement"; std::promise promise; diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index d2c5a095a305e..2e8cf6593c253 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -23,10 +23,10 @@ namespace paddle { namespace distributed { -std::vector GraphShard::get_batch(int start, int total_size) { +std::vector GraphShard::get_batch(int start, int end, int step) { if (start < 0) start = 0; std::vector res; - for (int pos = start; pos < start + total_size; pos++) { + for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) { res.push_back(bucket[pos]); } return res; @@ -52,15 +52,14 @@ GraphNode *GraphShard::find_node(uint64_t id) { } int32_t GraphTable::load(const std::string &path, const std::string ¶m) { - bool load_edge = (param[0] == 'e'); bool load_node = (param[0] == 'n'); if (load_edge) { bool reverse_edge = (param[1] == '<'); return this->load_edges(path, reverse_edge); } - if (load_node){ - std::string node_type = param.substr(1); + if (load_node) { + std::string node_type = param.substr(1); return this->load_nodes(path, node_type); } } @@ -125,18 +124,17 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) { std::string nt = values[0]; if (nt != node_type) { - continue; + continue; } std::vector feature; for (size_t slice = 2; slice < values.size(); slice++) { feature.push_back(values[slice]); } size_t index = shard_id - shard_start; - if(feature.size() > 0) { - shards[index].add_node(id, paddle::string::join_strings(feature, '\t')); - } - else { - shards[index].add_node(id, std::string("")); + if (feature.size() > 0) { + shards[index].add_node(id, paddle::string::join_strings(feature, '\t')); + } else { + shards[index].add_node(id, std::string("")); } } } @@ -188,7 +186,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { for (auto &shard : shards) { auto bucket = shard.get_bucket(); for (int i = 0; i < bucket.size(); i++) { - bucket[i]->build_sampler(sample_type); } + bucket[i]->build_sampler(sample_type); + } } return 0; } @@ -315,37 +314,27 @@ int GraphTable::random_sample_neighboors( } int32_t GraphTable::pull_graph_list(int start, int total_size, std::unique_ptr &buffer, - int &actual_size, bool need_feature) { + int &actual_size, bool need_feature, + int step) { if (start < 0) start = 0; int size = 0, cur_size; - if (total_size <= 0) { - actual_size = 0; - return 0; - } std::vector>> tasks; - for (size_t i = 0; i < shards.size(); i++) { + for (size_t i = 0; i < shards.size() && total_size > 0; i++) { cur_size = shards[i].get_size(); if (size + cur_size <= start) { size += cur_size; continue; } - if (size + cur_size - start >= total_size) { - tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( - [this, i, start, size, total_size]() -> std::vector { - return this->shards[i].get_batch(start - size, total_size); - })); - break; - } else { - tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( - [this, i, start, size, total_size, - cur_size]() -> std::vector { - return this->shards[i].get_batch(start - size, - size + cur_size - start); - })); - total_size -= size + cur_size - start; - size += cur_size; - start = size; - } + int count = std::min(1 + (size + cur_size - start - 1) / step, total_size); + int end = start + (count - 1) * step + 1; + tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( + [this, i, start, end, step, size]() -> std::vector { + + return this->shards[i].get_batch(start - size, end - size, step); + })); + start += count * step; + total_size -= count; + size += cur_size; } for (size_t i = 0; i < tasks.size(); ++i) { tasks[i].wait(); diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 106dcaaac2b00..d298dc963800e 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -45,7 +45,7 @@ class GraphShard { // bucket.resize(bucket_size); } std::vector &get_bucket() { return bucket; } - std::vector get_batch(int start, int total_size); + std::vector get_batch(int start, int end, int step); // int init_bucket_size(int shard_num) { // for (int i = bucket_low_bound;; i++) { // if (gcd(i, shard_num) == 1) return i; @@ -78,7 +78,8 @@ class GraphTable : public SparseTable { virtual ~GraphTable() {} virtual int32_t pull_graph_list(int start, int size, std::unique_ptr &buffer, - int &actual_size, bool need_feature); + int &actual_size, bool need_feature, + int step); virtual int32_t random_sample_neighboors( uint64_t *node_ids, int sample_size, diff --git a/paddle/fluid/distributed/table/graph_edge.h b/paddle/fluid/distributed/table/graph_edge.h index 2a838f0b763e3..3dfe5a6f357a7 100644 --- a/paddle/fluid/distributed/table/graph_edge.h +++ b/paddle/fluid/distributed/table/graph_edge.h @@ -13,33 +13,34 @@ // limitations under the License. #pragma once -#include +#include #include +#include namespace paddle { namespace distributed { - class GraphEdgeBlob { -public: + public: GraphEdgeBlob() {} virtual ~GraphEdgeBlob() {} - const size_t size() {return id_arr.size();} + size_t size() { return id_arr.size(); } virtual void add_edge(uint64_t id, float weight); - const uint64_t get_id(int idx) { return id_arr[idx]; } - virtual const float get_weight(int idx) { return 1; } -protected: + uint64_t get_id(int idx) { return id_arr[idx]; } + virtual float get_weight(int idx) { return 1; } + + protected: std::vector id_arr; }; -class WeightedGraphEdgeBlob: public GraphEdgeBlob{ -public: +class WeightedGraphEdgeBlob : public GraphEdgeBlob { + public: WeightedGraphEdgeBlob() {} virtual ~WeightedGraphEdgeBlob() {} virtual void add_edge(uint64_t id, float weight); - virtual const float get_weight(int idx) { return weight_arr[idx]; } -protected: + virtual float get_weight(int idx) { return weight_arr[idx]; } + + protected: std::vector weight_arr; }; - } } diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h index 65a3191c50255..40d2abb6c86ae 100644 --- a/paddle/fluid/distributed/table/table.h +++ b/paddle/fluid/distributed/table/table.h @@ -90,7 +90,8 @@ class Table { // only for graph table virtual int32_t pull_graph_list(int start, int total_size, std::unique_ptr &buffer, - int &actual_size, bool need_feature) { + int &actual_size, bool need_feature, + int step = 1) { return 0; } // only for graph table diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index e10630b016600..efee0d9441ef2 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -335,12 +335,12 @@ void RunBrpcPushSparse() { ASSERT_EQ(0, vs[0].size()); std::vector nodes; - pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes); + pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes); pull_status.wait(); ASSERT_EQ(nodes.size(), 1); ASSERT_EQ(nodes[0].get_id(), 37); nodes.clear(); - pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, nodes); + pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, 1, nodes); pull_status.wait(); ASSERT_EQ(nodes.size(), 1); ASSERT_EQ(nodes[0].get_id(), 59); @@ -373,7 +373,7 @@ void RunBrpcPushSparse() { // client2.load_edge_file(std::string("user2item"), std::string(file_name), // 0); nodes.clear(); - nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4); + nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1); for (auto g : nodes) { std::cout << "node_ids: " << g.get_id() << std::endl;