From 30b0226ea970671c81097b51f3cb87e9f91af5aa Mon Sep 17 00:00:00 2001 From: yelrose <270018958@qq.com> Date: Fri, 6 Aug 2021 07:48:31 +0000 Subject: [PATCH] add remove graph node; add set_feature --- paddle/fluid/distributed/service/graph_brpc_client.cc | 1 + paddle/fluid/distributed/service/graph_brpc_client.h | 1 + paddle/fluid/distributed/service/graph_brpc_server.cc | 2 ++ paddle/fluid/distributed/service/graph_brpc_server.h | 1 + paddle/fluid/distributed/service/graph_py_service.cc | 1 + paddle/fluid/distributed/service/graph_py_service.h | 1 + paddle/fluid/distributed/table/common_graph_table.cc | 1 + paddle/fluid/distributed/table/common_graph_table.h | 1 + paddle/fluid/distributed/test/graph_node_test.cc | 1 + paddle/fluid/pybind/fleet_py.cc | 1 + 10 files changed, 11 insertions(+) diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc index a724662dbcf9b..8dea0190ae837 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/service/graph_brpc_client.cc @@ -27,6 +27,7 @@ namespace paddle { namespace distributed { + void GraphPsService_Stub::service( ::google::protobuf::RpcController *controller, const ::paddle::distributed::PsRequestMessage *request, diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h index 8acb2047b8e97..863ba80f2ac2c 100644 --- a/paddle/fluid/distributed/service/graph_brpc_client.h +++ b/paddle/fluid/distributed/service/graph_brpc_client.h @@ -84,6 +84,7 @@ class GraphBrpcClient : public BrpcPsClient { const std::vector& feature_names, const std::vector>& features); + virtual std::future clear_nodes(uint32_t table_id); virtual std::future add_graph_node( uint32_t table_id, std::vector& node_id_list, diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc index 4de552af31455..972a13d8dc681 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/service/graph_brpc_server.cc @@ -436,6 +436,8 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, buffer += feat_len; } } + + ((GraphTable *)table)->set_node_feat(node_ids, feature_names, features); return 0; diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h index 57263dfa812cc..6b4853fa67992 100644 --- a/paddle/fluid/distributed/service/graph_brpc_server.h +++ b/paddle/fluid/distributed/service/graph_brpc_server.h @@ -83,6 +83,7 @@ class GraphBrpcService : public PsBaseService { const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc index 9a03f7e02ab21..c79174439501c 100644 --- a/paddle/fluid/distributed/service/graph_py_service.cc +++ b/paddle/fluid/distributed/service/graph_py_service.cc @@ -165,6 +165,7 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() { ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = worker_proto->mutable_downpour_worker_param(); + for (auto& tuple : this->table_id_map) { VLOG(0) << " make a new table " << tuple.second; ::paddle::distributed::TableParameter* worker_sparse_table_proto = diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h index a8e082a1bea47..db017fe31ea1c 100644 --- a/paddle/fluid/distributed/service/graph_py_service.h +++ b/paddle/fluid/distributed/service/graph_py_service.h @@ -163,6 +163,7 @@ class GraphPyClient : public GraphPyService { int start, int size, int step = 1); ::paddle::distributed::PSParameter GetWorkerProto(); + protected: mutable std::mutex mutex_; int client_id; diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index a8a676ea569f1..75a0135e808ac 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -624,6 +624,7 @@ int32_t GraphTable::initialize() { // shards.resize(shard_num_per_table); shards = std::vector(shard_num_per_table, GraphShard(shard_num)); return 0; + } } }; diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index eb766989431b1..211bea5ee5a16 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -46,6 +46,7 @@ class GraphShard { } return res; } + GraphNode *add_graph_node(uint64_t id); FeatureNode *add_feature_node(uint64_t id); Node *find_node(uint64_t id); diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index e5e23e9faebb4..e812cc7d551a2 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -569,6 +569,7 @@ void RunBrpcPushSparse() { ASSERT_TRUE(node_feat[1][0] == "helloworld"); + // Test string node_ids.clear(); node_ids.push_back(37); diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 2cb5d3a1bd46f..f261b652c72da 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -229,6 +229,7 @@ void BindGraphPyClient(py::module* m) { .def("bind_local_server", &GraphPyClient::bind_local_server); } + using paddle::distributed::TreeIndex; using paddle::distributed::IndexWrapper; using paddle::distributed::IndexNode;