Skip to content

Commit

Permalink
add remove graph node; add set_feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Yelrose committed Aug 6, 2021
1 parent 30b0226 commit ead5c00
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 66 deletions.
20 changes: 11 additions & 9 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
namespace paddle {
namespace distributed {


void GraphPsService_Stub::service(
::google::protobuf::RpcController *controller,
const ::paddle::distributed::PsRequestMessage *request,
Expand Down Expand Up @@ -497,17 +496,19 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(request_call_num);
std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
if(features_idx_buckets[request_idx].size() == 0) {
if (features_idx_buckets[request_idx].size() == 0) {
features_idx_buckets[request_idx].resize(feature_names.size());
}
for(int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
features_idx_buckets[request_idx][feat_idx].push_back(features[feat_idx][query_idx]);
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
features_idx_buckets[request_idx][feat_idx].push_back(
features[feat_idx][query_idx]);
}
}

Expand Down Expand Up @@ -553,10 +554,12 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
std::string set_feature = "";
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = features_idx_buckets[request_idx][feat_idx][node_idx].size();
size_t feat_len =
features_idx_buckets[request_idx][feat_idx][node_idx].size();
set_feature.append((char *)&feat_len, sizeof(size_t));
set_feature.append(features_idx_buckets[request_idx][feat_idx][node_idx].data(),
feat_len);
set_feature.append(
features_idx_buckets[request_idx][feat_idx][node_idx].data(),
feat_len);
}
}
closure->request(request_idx)
Expand All @@ -572,7 +575,6 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
return fut;
}


int32_t GraphBrpcClient::initialize() {
// set_shard_num(_config.shard_num());
BrpcPsClient::initialize();
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class GraphBrpcClient : public BrpcPsClient {
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);


virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list,
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -425,19 +425,18 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
std::vector<std::vector<std::string>> features(
feature_names.size(), std::vector<std::string>(node_num));

const char* buffer = request.params(2).c_str();
const char *buffer = request.params(2).c_str();

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
features[feat_idx][node_idx] = feat;
buffer += feat_len;
}
}


((GraphTable *)table)->set_node_feat(node_ids, feature_names, features);

return 0;
Expand Down
5 changes: 2 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ ::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
worker_proto->mutable_downpour_worker_param();


for (auto& tuple : this->table_id_map) {
VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* worker_sparse_table_proto =
Expand Down Expand Up @@ -333,15 +332,15 @@ std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(

void GraphPyClient::set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (this->table_id_map.count(node_type)) {
uint32_t table_id = this->table_id_map[node_type];
auto status =
worker_ptr->set_node_feat(table_id, node_ids, feature_names, features);
status.wait();
}
return ;
return;
}

std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
Expand Down
8 changes: 3 additions & 5 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,13 @@ class GraphPyClient : public GraphPyService {
std::vector<std::vector<std::string>> get_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names);
void set_node_feat(
std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
void set_node_feat(std::string node_type, std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();


protected:
mutable std::mutex mutex_;
int client_id;
Expand Down
63 changes: 29 additions & 34 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include <chrono>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/generator.h"

namespace paddle {
namespace distributed {
Expand Down Expand Up @@ -406,31 +406,30 @@ int32_t GraphTable::random_sample_neighboors(
int thread_pool_index = get_thread_pool_index(node_id);
auto rng = _shards_task_rng_pool[thread_pool_index];

tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue(
[&]() -> int {
Node *node = find_node(node_id);
tasks.push_back(_shards_task_pool[thread_pool_index]->enqueue([&]() -> int {
Node *node = find_node(node_id);

if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
if (node == nullptr) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size = res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
}
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
Expand Down Expand Up @@ -470,10 +469,10 @@ int32_t GraphTable::get_node_feat(const std::vector<uint64_t> &node_ids,
return 0;
}


int32_t GraphTable::set_node_feat(const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
int32_t GraphTable::set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx) {
Expand All @@ -498,9 +497,6 @@ int32_t GraphTable::set_node_feat(const std::vector<uint64_t> &node_ids,
return 0;
}




std::pair<int32_t, std::string> GraphTable::parse_feature(
std::string feat_str) {
// Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
Expand Down Expand Up @@ -624,7 +620,6 @@ int32_t GraphTable::initialize() {
// shards.resize(shard_num_per_table);
shards = std::vector<GraphShard>(shard_num_per_table, GraphShard(shard_num));
return 0;

}
}
};
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ class GraphTable : public SparseTable {
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);

virtual int32_t set_node_feat(const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat(
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);

protected:
std::vector<GraphShard> shards;
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -560,16 +560,15 @@ void RunBrpcPushSparse() {

node_feat[1][0] = "helloworld";

client1.set_node_feat(std::string("user"), node_ids, feature_names, node_feat);
client1.set_node_feat(std::string("user"), node_ids, feature_names,
node_feat);

//sleep(5);
// sleep(5);
node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names);
VLOG(0) << "get_node_feat: " << node_feat[1][0];
ASSERT_TRUE(node_feat[1][0] == "helloworld");



// Test string
node_ids.clear();
node_ids.push_back(37);
Expand Down
5 changes: 2 additions & 3 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,18 @@ void BindGraphPyClient(py::module* m) {
std::vector<uint64_t> node_ids,
std::vector<std::string> feature_names,
std::vector<std::vector<py::bytes>> bytes_feats) {
std::vector<std::vector<std::string >> feats(bytes_feats.size());
std::vector<std::vector<std::string>> feats(bytes_feats.size());
for (int i = 0; i < bytes_feats.size(); ++i) {
for (int j = 0; j < bytes_feats[i].size(); ++j) {
feats[i].push_back(std::string(bytes_feats[i][j]));
}
}
self.set_node_feat(node_type, node_ids, feature_names, feats);
return ;
return;
})
.def("bind_local_server", &GraphPyClient::bind_local_server);
}


using paddle::distributed::TreeIndex;
using paddle::distributed::IndexWrapper;
using paddle::distributed::IndexNode;
Expand Down

0 comments on commit ead5c00

Please sign in to comment.