diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc
index 1b54ce986d9c2..47bd70ff5d980 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -174,7 +174,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
 }
 std::future<int32_t> GraphBrpcClient::pull_graph_list(
     uint32_t table_id, int server_index, int start, int size, int step,
-    std::vector<GraphNode> &res) {
+    std::vector<FeatureNode> &res) {
   DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
     int ret = 0;
     auto *closure = (DownpourBrpcClosure *)done;
@@ -190,9 +190,9 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
       io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
       int index = 0;
       while (index < bytes_size) {
-        GraphNode node;
+        FeatureNode node;
         node.recover_from_buffer(buffer + index);
-        index += node.get_size(true);
+        index += node.get_size(false);
         res.push_back(node);
       }
     }
diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h
index e762409fbeeda..87e481d578a84 100644
--- a/paddle/fluid/distributed/service/graph_brpc_client.h
+++ b/paddle/fluid/distributed/service/graph_brpc_client.h
@@ -42,7 +42,7 @@ class GraphBrpcClient : public BrpcPsClient {
   virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
                                                int server_index, int start,
                                                int size, int step,
-                                               std::vector<GraphNode> &res);
+                                               std::vector<FeatureNode> &res);
   virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
                                                    int server_index,
                                                    int sample_size,
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc
index 765c4e9254254..2f619fdb2a661 100644
--- a/paddle/fluid/distributed/service/graph_brpc_server.cc
+++ b/paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -270,7 +270,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
   int step = *(int *)(request.params(2).c_str());
   std::unique_ptr<char[]> buffer;
   int actual_size;
-  table->pull_graph_list(start, size, buffer, actual_size, true, step);
+  table->pull_graph_list(start, size, buffer, actual_size, false, step);
   cntl->response_attachment().append(buffer.get(), actual_size);
   return 0;
 }
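Illustrative note, not part of the patch: the server packs each node's serialization back to back into a single attachment, and the client advances through it by asking the freshly decoded node for its own size. Since the service now passes need_feature = false, each record should reduce to the fixed header that FeatureNode::get_size(false) reports, i.e. id_size plus int_size (8 + 4 bytes on typical platforms). A minimal standalone sketch of that walk; decode_node_list is a hypothetical helper name.

// Hypothetical helper mirroring the decode loop in
// GraphBrpcClient::pull_graph_list; not part of the patch.
#include <vector>

#include "paddle/fluid/distributed/table/graph_node.h"

std::vector<paddle::distributed::FeatureNode> decode_node_list(char *buffer,
                                                               int bytes_size) {
  std::vector<paddle::distributed::FeatureNode> res;
  int index = 0;
  while (index < bytes_size) {
    paddle::distributed::FeatureNode node;
    node.recover_from_buffer(buffer + index);
    // With features omitted on the wire, the stride is the fixed id + feat_num
    // header, which is exactly what get_size(false) returns.
    index += node.get_size(false);
    res.push_back(node);
  }
  return res;
}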
diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc
index 2fb10bff7df51..84ceb102a1e0f 100644
--- a/paddle/fluid/distributed/service/graph_py_service.cc
+++ b/paddle/fluid/distributed/service/graph_py_service.cc
@@ -310,11 +310,11 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
   }
   return v;
 }
-std::vector<GraphNode> GraphPyClient::pull_graph_list(std::string name,
+std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
                                                        int server_index,
                                                        int start, int size,
                                                        int step) {
-  std::vector<GraphNode> res;
+  std::vector<FeatureNode> res;
   if (this->table_id_map.count(name)) {
     uint32_t table_id = this->table_id_map[name];
     auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h
index 44147924ba49d..beab26085058b 100644
--- a/paddle/fluid/distributed/service/graph_py_service.h
+++ b/paddle/fluid/distributed/service/graph_py_service.h
@@ -150,7 +150,7 @@ class GraphPyClient : public GraphPyService {
       std::string name, std::vector<uint64_t> node_ids, int sample_size);
   std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
                                             int sample_size);
-  std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
+  std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
                                          int start, int size, int step = 1);
   ::paddle::distributed::PSParameter GetWorkerProto();
diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h
index 1b61e35641322..1b50a0c37098d 100644
--- a/paddle/fluid/distributed/service/ps_client.h
+++ b/paddle/fluid/distributed/service/ps_client.h
@@ -167,7 +167,7 @@ class PSClient {
   virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
                                                int server_index, int start,
                                                int size, int step,
-                                               std::vector<GraphNode> &res) {
+                                               std::vector<FeatureNode> &res) {
     LOG(FATAL) << "Did not implement";
     std::promise<int32_t> promise;
     std::future<int32_t> fut = promise.get_future();
diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc
index 7d67775635542..7774a402cb49e 100644
--- a/paddle/fluid/distributed/table/common_graph_table.cc
+++ b/paddle/fluid/distributed/table/common_graph_table.cc
@@ -23,9 +23,9 @@
 namespace paddle {
 namespace distributed {
 
-std::vector<GraphNode *> GraphShard::get_batch(int start, int end, int step) {
+std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
   if (start < 0) start = 0;
-  std::vector<GraphNode *> res;
+  std::vector<Node *> res;
   for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
     res.push_back(bucket[pos]);
   }
@@ -34,21 +34,29 @@ std::vector<GraphNode *> GraphShard::get_batch(int start, int end, int step) {
 
 size_t GraphShard::get_size() { return bucket.size(); }
 
-GraphNode *GraphShard::add_node(uint64_t id, std::string feature) {
-  if (node_location.find(id) != node_location.end())
-    return bucket[node_location[id]];
-  node_location[id] = bucket.size();
-  bucket.push_back(new GraphNode(id, feature));
-  return bucket.back();
+GraphNode *GraphShard::add_graph_node(uint64_t id) {
+  if (node_location.find(id) == node_location.end()){
+    node_location[id] = bucket.size();
+    bucket.push_back(new GraphNode(id));
+  }
+  return (GraphNode*)bucket[node_location[id]];
+}
+
+FeatureNode *GraphShard::add_feature_node(uint64_t id) {
+  if (node_location.find(id) == node_location.end()){
+    node_location[id] = bucket.size();
+    bucket.push_back(new FeatureNode(id));
+  }
+  return (FeatureNode*)bucket[node_location[id]];
 }
 
 void GraphShard::add_neighboor(uint64_t id, uint64_t dst_id, float weight) {
-  add_node(id, std::string(""))->add_edge(dst_id, weight);
+  find_node(id)->add_edge(dst_id, weight);
 }
 
-GraphNode *GraphShard::find_node(uint64_t id) {
+Node *GraphShard::find_node(uint64_t id) {
   auto iter = node_location.find(id);
-  return iter == node_location.end() ? NULL : bucket[iter->second];
+  return iter == node_location.end() ? nullptr : bucket[iter->second];
 }
 
 int32_t GraphTable::load(const std::string &path, const std::string &param) {
@@ -132,9 +140,10 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
     }
     size_t index = shard_id - shard_start;
     if (feature.size() > 0) {
-      shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
+      // TODO add feature
+      shards[index].add_feature_node(id);
     } else {
-      shards[index].add_node(id, std::string(""));
+      shards[index].add_feature_node(id);
     }
   }
 }
@@ -175,7 +184,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
     }
 
     size_t index = src_shard_id - shard_start;
-    shards[index].add_node(src_id, std::string(""))->build_edges(is_weighted);
+    shards[index].add_graph_node(src_id)->build_edges(is_weighted);
     shards[index].add_neighboor(src_id, dst_id, weight);
   }
 }
@@ -192,13 +201,13 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
   return 0;
 }
 
-GraphNode *GraphTable::find_node(uint64_t id) {
+Node *GraphTable::find_node(uint64_t id) {
   size_t shard_id = id % shard_num;
   if (shard_id >= shard_end || shard_id < shard_start) {
     return NULL;
   }
   size_t index = shard_id - shard_start;
-  GraphNode *node = shards[index].find_node(id);
+  Node *node = shards[index].find_node(id);
   return node;
 }
 uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
@@ -282,7 +291,7 @@ int GraphTable::random_sample_neighboors(
     int &actual_size = actual_sizes[idx];
     tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
        [&]() -> int {
-          GraphNode *node = find_node(node_id);
+          Node *node = find_node(node_id);
 
           if (node == NULL) {
            actual_size = 0;
@@ -290,7 +299,7 @@ int GraphTable::random_sample_neighboors(
           }
           std::vector<int> res = node->sample_k(sample_size);
           actual_size =
-              res.size() * (GraphNode::id_size + GraphNode::weight_size);
+              res.size() * (Node::id_size + Node::weight_size);
           int offset = 0;
           uint64_t id;
           float weight;
@@ -299,10 +308,10 @@ int GraphTable::random_sample_neighboors(
           for (int &x : res) {
             id = node->get_neighbor_id(x);
             weight = node->get_neighbor_weight(x);
-            memcpy(buffer_addr + offset, &id, GraphNode::id_size);
-            offset += GraphNode::id_size;
-            memcpy(buffer_addr + offset, &weight, GraphNode::weight_size);
-            offset += GraphNode::weight_size;
+            memcpy(buffer_addr + offset, &id, Node::id_size);
+            offset += Node::id_size;
+            memcpy(buffer_addr + offset, &weight, Node::weight_size);
+            offset += Node::weight_size;
           }
           return 0;
         }));
@@ -318,7 +327,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
                                     int step) {
   if (start < 0) start = 0;
   int size = 0, cur_size;
-  std::vector<std::future<std::vector<GraphNode *>>> tasks;
+  std::vector<std::future<std::vector<Node *>>> tasks;
   for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
     cur_size = shards[i].get_size();
     if (size + cur_size <= start) {
@@ -328,7 +337,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
     int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
     int end = start + (count - 1) * step + 1;
     tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
-        [this, i, start, end, step, size]() -> std::vector<GraphNode *> {
+        [this, i, start, end, step, size]() -> std::vector<Node *> {
          return this->shards[i].get_batch(start - size, end - size, step);
        }));
@@ -340,7 +349,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
     tasks[i].wait();
   }
   size = 0;
-  std::vector<std::vector<GraphNode *>> res;
+  std::vector<std::vector<Node *>> res;
   for (size_t i = 0; i < tasks.size(); i++) {
     res.push_back(tasks[i].get());
     for (size_t j = 0; j < res.back().size(); j++) {
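Illustrative note, not part of the patch: random_sample_neighboors emits one fixed-size record per sampled neighbor, Node::id_size bytes of id followed by Node::weight_size bytes of weight, so actual_size is a whole multiple of those 12 bytes on typical platforms. A standalone sketch of reading such a buffer back; decode_neighbors is a hypothetical helper name.

// Hypothetical helper; not part of the patch.
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/table/graph_node.h"

using paddle::distributed::Node;

std::vector<std::pair<uint64_t, float>> decode_neighbors(const char *buffer,
                                                         int actual_size) {
  std::vector<std::pair<uint64_t, float>> out;
  int offset = 0;
  // Each record is Node::id_size bytes of id followed by Node::weight_size
  // bytes of weight, exactly as written in GraphTable::random_sample_neighboors.
  while (offset + Node::id_size + Node::weight_size <= actual_size) {
    uint64_t id;
    float weight;
    std::memcpy(&id, buffer + offset, Node::id_size);
    offset += Node::id_size;
    std::memcpy(&weight, buffer + offset, Node::weight_size);
    offset += Node::weight_size;
    out.emplace_back(id, weight);
  }
  return out;
}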
diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h
index 8d03b70f69adb..b13cb69c47b7a 100644
--- a/paddle/fluid/distributed/table/common_graph_table.h
+++ b/paddle/fluid/distributed/table/common_graph_table.h
@@ -44,8 +44,8 @@ class GraphShard {
     // bucket_size = init_bucket_size(shard_num);
     // bucket.resize(bucket_size);
   }
-  std::vector<GraphNode *> &get_bucket() { return bucket; }
-  std::vector<GraphNode *> get_batch(int start, int end, int step);
+  std::vector<Node *> &get_bucket() { return bucket; }
+  std::vector<Node *> get_batch(int start, int end, int step);
   // int init_bucket_size(int shard_num) {
   //   for (int i = bucket_low_bound;; i++) {
   //     if (gcd(i, shard_num) == 1) return i;
@@ -59,8 +59,9 @@ class GraphShard {
     }
     return res;
   }
-  GraphNode *add_node(uint64_t id, std::string feature);
-  GraphNode *find_node(uint64_t id);
+  GraphNode *add_graph_node(uint64_t id);
+  FeatureNode *add_feature_node(uint64_t id);
+  Node *find_node(uint64_t id);
   void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
   // std::unordered_map<uint64_t, std::list<GraphNode *>::iterator>
   std::unordered_map<uint64_t, int> get_node_location() {
@@ -70,7 +71,7 @@ class GraphShard {
  private:
   std::unordered_map<uint64_t, int> node_location;
   int shard_num;
-  std::vector<GraphNode *> bucket;
+  std::vector<Node *> bucket;
 };
 class GraphTable : public SparseTable {
  public:
@@ -98,8 +99,8 @@ class GraphTable : public SparseTable {
 
   int32_t load_edges(const std::string &path, bool reverse);
   int32_t load_nodes(const std::string &path, std::string node_type);
-
-  GraphNode *find_node(uint64_t id);
+
+  Node *find_node(uint64_t id);
   virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {
     return 0;
diff --git a/paddle/fluid/distributed/table/graph_node.cc b/paddle/fluid/distributed/table/graph_node.cc
index 4e4e210cccec6..8c20fc302f8d7 100644
--- a/paddle/fluid/distributed/table/graph_node.cc
+++ b/paddle/fluid/distributed/table/graph_node.cc
@@ -29,12 +29,37 @@ GraphNode::~GraphNode() {
   }
 }
 
-int GraphNode::weight_size = sizeof(float);
-int GraphNode::id_size = sizeof(uint64_t);
-int GraphNode::int_size = sizeof(int);
-int GraphNode::get_size(bool need_feature) {
-  return id_size + int_size + (need_feature ? feature.size() : 0);
+int Node::weight_size = sizeof(float);
+int Node::id_size = sizeof(uint64_t);
+int Node::int_size = sizeof(int);
+
+int Node::get_size(bool need_feature) {
+  return id_size + int_size;
 }
+
+void Node::to_buffer(char* buffer, bool need_feature) {
+  memcpy(buffer, &id, id_size);
+  buffer += id_size;
+
+  int feat_num = 0;
+  memcpy(buffer, &feat_num, sizeof(int));
+}
+
+void Node::recover_from_buffer(char* buffer) {
+  memcpy(&id, buffer, id_size);
+}
+
+int FeatureNode::get_size(bool need_feature) {
+  int size = id_size + int_size;  // id, feat_num
+  if (need_feature){
+    size += feature.size() * int_size;
+    for (const std::string& fea: feature){
+      size += fea.size();
+    }
+  }
+  return size;
+}
+
 void GraphNode::build_edges(bool is_weighted) {
   if (edges == nullptr){
     if (is_weighted == true){
@@ -52,28 +77,48 @@ void GraphNode::build_sampler(std::string sample_type) {
   }
   sampler->build(edges);
 }
-void GraphNode::to_buffer(char* buffer, bool need_feature) {
-  int size = get_size(need_feature);
-  memcpy(buffer, &size, int_size);
+void FeatureNode::to_buffer(char* buffer, bool need_feature) {
+  memcpy(buffer, &id, id_size);
+  buffer += id_size;
+
+  int feat_num = 0;
+  int feat_len;
   if (need_feature) {
-    memcpy(buffer + int_size, feature.c_str(), feature.size());
-    memcpy(buffer + int_size + feature.size(), &id, id_size);
+    feat_num += feature.size();
+    memcpy(buffer, &feat_num, sizeof(int));
+    buffer += sizeof(int);
+    for (int i = 0; i < feat_num; ++i){
+      feat_len = feature[i].size();
+      memcpy(buffer, &feat_len, sizeof(int));
+      buffer += sizeof(int);
+      memcpy(buffer, feature[i].c_str(), feature[i].size());
+      buffer += feature[i].size();
+    }
   } else {
-    memcpy(buffer + int_size, &id, id_size);
+    memcpy(buffer, &feat_num, sizeof(int));
   }
 }
-void GraphNode::recover_from_buffer(char* buffer) {
-  int size;
-  memcpy(&size, buffer, int_size);
-  int feature_size = size - id_size - int_size;
-  char str[feature_size + 1];
-  memcpy(str, buffer + int_size, feature_size);
-  str[feature_size] = '\0';
-  feature = str;
-  memcpy(&id, buffer + int_size + feature_size, id_size);
-  // int int_state;
-  // memcpy(&int_state, buffer + int_size + feature_size + id_size, enum_size);
-  // type = GraphNodeType(int_state);
+void FeatureNode::recover_from_buffer(char* buffer) {
+
+  int feat_num, feat_len;
+  memcpy(&id, buffer, id_size);
+  buffer += id_size;
+
+  memcpy(&feat_num, buffer, sizeof(int));
+  buffer += sizeof(int);
+
+  feature.clear();
+  for (int i = 0; i < feat_num; ++i) {
+    memcpy(&feat_len, buffer, sizeof(int));
+    buffer += sizeof(int);
+
+    char str[feat_len + 1];
+    memcpy(str, buffer, feat_len);
+    buffer += feat_len;
+    str[feat_len] = '\0';
+    feature.push_back(std::string(str));
+  }
+}
 }
 }
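Illustrative note, not part of the patch: the FeatureNode layout written by to_buffer(buffer, true) is id (id_size bytes), then feat_num (int_size bytes), then feat_num repetitions of feat_len followed by the raw feature bytes; with need_feature == false only the id and a zero feat_num are written, which is why get_size(false) stays at id_size + int_size. A minimal round trip under that layout.

// Round-trip sketch; not part of the patch.
#include <vector>

#include "paddle/fluid/distributed/table/graph_node.h"

int main() {
  paddle::distributed::FeatureNode src(65);
  src.add_feature("hhhh");
  src.add_feature("12345");

  // get_size(true) accounts for the id, the feature count and every
  // (length, bytes) pair, so the buffer is sized exactly.
  int size = src.get_size(true);
  std::vector<char> buf(size);
  src.to_buffer(buf.data(), true);

  paddle::distributed::FeatureNode dst;
  dst.recover_from_buffer(buf.data());
  // Expected: dst.get_id() == 65, dst.get_feature(0) == "hhhh",
  // dst.get_feature(1) == "12345".
  return 0;
}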
diff --git a/paddle/fluid/distributed/table/graph_node.h b/paddle/fluid/distributed/table/graph_node.h
index 4d2b866d5d822..74fb110830321 100644
--- a/paddle/fluid/distributed/table/graph_node.h
+++ b/paddle/fluid/distributed/table/graph_node.h
@@ -18,32 +18,68 @@
 namespace paddle {
 namespace distributed {
 
-class GraphNode {
+class Node {
  public:
-  GraphNode(): sampler(nullptr), edges(nullptr) { }
-  GraphNode(uint64_t id, std::string feature)
-      : id(id), feature(feature), sampler(nullptr), edges(nullptr) {}
-  virtual ~GraphNode();
+  Node(){}
+  Node(uint64_t id)
+      : id(id) {}
+  virtual ~Node() {}
   static int id_size, int_size, weight_size;
   uint64_t get_id() { return id; }
   void set_id(uint64_t id) { this->id = id; }
-  void set_feature(std::string feature) { this->feature = feature; }
-  std::string get_feature() { return feature; }
+
+  virtual void build_edges(bool is_weighted) {}
+  virtual void build_sampler(std::string sample_type) {}
+  virtual void add_edge(uint64_t id, float weight) {}
+  virtual std::vector<int> sample_k(int k) { return std::vector<int>(); }
+  virtual uint64_t get_neighbor_id(int idx){ return 0; }
+  virtual float get_neighbor_weight(int idx){ return 1.; }
+
   virtual int get_size(bool need_feature);
-  virtual void build_edges(bool is_weighted);
-  virtual void build_sampler(std::string sample_type);
   virtual void to_buffer(char *buffer, bool need_feature);
   virtual void recover_from_buffer(char *buffer);
-  virtual void add_edge(uint64_t id, float weight) { edges->add_edge(id, weight); }
-  std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
-  uint64_t get_neighbor_id(int idx){return edges->get_id(idx);}
-  float get_neighbor_weight(int idx){return edges->get_weight(idx);}
+  virtual void add_feature(std::string feature) { }
+  virtual std::string get_feature(int idx) { return std::string(""); }
 
  protected:
   uint64_t id;
-  std::string feature;
+
+};
+
+class GraphNode: public Node {
+ public:
+  GraphNode(): Node(), sampler(nullptr), edges(nullptr) { }
+  GraphNode(uint64_t id)
+      : Node(id), sampler(nullptr), edges(nullptr) {}
+  virtual ~GraphNode();
+  virtual void build_edges(bool is_weighted);
+  virtual void build_sampler(std::string sample_type);
+  virtual void add_edge(uint64_t id, float weight) { edges->add_edge(id, weight); }
+  virtual std::vector<int> sample_k(int k) { return sampler->sample_k(k); }
+  virtual uint64_t get_neighbor_id(int idx){return edges->get_id(idx);}
+  virtual float get_neighbor_weight(int idx){return edges->get_weight(idx);}
+
+ protected:
   Sampler *sampler;
   GraphEdgeBlob * edges;
 };
+
+
+class FeatureNode: public Node{
+ public:
+  FeatureNode(): Node() { }
+  FeatureNode(uint64_t id) : Node(id) {}
+  virtual ~FeatureNode() {}
+  virtual int get_size(bool need_feature);
+  virtual void to_buffer(char *buffer, bool need_feature);
+  virtual void recover_from_buffer(char *buffer);
+  virtual void add_feature(std::string feature) { this->feature.push_back(feature); }
+  virtual std::string get_feature(int idx) { return feature[idx]; }
+
+ protected:
+  std::vector<std::string> feature;
+};
+
+
 }
 }
diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc
index 2ba5946cc443f..508d9b63f3d5f 100644
--- a/paddle/fluid/distributed/test/graph_node_test.cc
+++ b/paddle/fluid/distributed/test/graph_node_test.cc
@@ -334,7 +334,7 @@ void RunBrpcPushSparse() {
   pull_status.wait();
   ASSERT_EQ(0, vs[0].size());
 
-  std::vector<GraphNode> nodes;
+  std::vector<FeatureNode> nodes;
   pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes);
   pull_status.wait();
   ASSERT_EQ(nodes.size(), 1);
@@ -450,15 +450,15 @@ void RunBrpcPushSparse() {
 
 void testGraphToBuffer() {
   ::paddle::distributed::GraphNode s, s1;
-  s.set_feature("hhhh");
+  s.add_feature("hhhh");
   s.set_id(65);
   int size = s.get_size(true);
   char str[size];
   s.to_buffer(str, true);
   s1.recover_from_buffer(str);
   ASSERT_EQ(s.get_id(), s1.get_id());
-  VLOG(0) << s.get_feature();
-  VLOG(0) << s1.get_feature();
+  VLOG(0) << s.get_feature(0);
+  VLOG(0) << s1.get_feature(0);
 }
 TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
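Illustrative note, not part of the patch: after this split, edge storage lives in GraphNode and feature storage in FeatureNode, while code that only holds a Node* (for example the return value of GraphShard::find_node) goes through the virtual interface, whose base-class defaults are harmless no-ops. A small sketch of that contract.

// Sketch of the Node / GraphNode / FeatureNode contract; not part of the patch.
#include <vector>

#include "paddle/fluid/distributed/table/graph_node.h"

using paddle::distributed::FeatureNode;
using paddle::distributed::GraphNode;
using paddle::distributed::Node;

void sketch() {
  // Topology side: a GraphNode owns an edge blob (weighted here).
  GraphNode g(1);
  g.build_edges(/*is_weighted=*/true);
  g.add_edge(/*id=*/2, /*weight=*/0.5);

  // Feature side: a FeatureNode owns string features and the richer serializer.
  FeatureNode f(1);
  f.add_feature("some feature");

  // Through the base pointer, operations a subclass does not implement fall
  // back to the defaults declared on Node instead of touching missing state.
  Node *n = &f;
  std::vector<int> ks = n->sample_k(3);  // empty: a FeatureNode has no edges
  float w = n->get_neighbor_weight(0);   // base-class default weight of 1.0
  (void)ks;
  (void)w;
}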