diff --git a/paddle/fluid/distributed/table/common_graph_table.cc b/paddle/fluid/distributed/table/common_graph_table.cc index 8bcf77975afc5..107c619235ad1 100644 --- a/paddle/fluid/distributed/table/common_graph_table.cc +++ b/paddle/fluid/distributed/table/common_graph_table.cc @@ -64,21 +64,22 @@ size_t GraphShard::get_size() { return res; } -std::list::iterator GraphShard::add_node(GraphNode *node) { - if (node_location.find(node->get_id()) != node_location.end()) - return node_location.find(node->get_id())->second; +std::list::iterator GraphShard::add_node(uint64_t id, std::string feature) { + if (node_location.find(id) != node_location.end()) + return node_location.find(id)->second; - int index = node->get_id() % shard_num % bucket_size; + int index = id % shard_num % bucket_size; + GraphNode *node = new GraphNode(id, feature); std::list::iterator iter = bucket[index].insert(bucket[index].end(), node); - node_location[node->get_id()] = iter; + node_location[id] = iter; return iter; } void GraphShard::add_neighboor(uint64_t id, GraphEdge *edge) { - (*add_node(new GraphNode(id, std::string(""))))->add_edge(edge); + (*add_node(id, std::string("")))->add_edge(edge); } GraphNode *GraphShard::find_node(uint64_t id) { @@ -88,13 +89,55 @@ GraphNode *GraphShard::find_node(uint64_t id) { int32_t GraphTable::load(const std::string &path, const std::string ¶m) { auto cmd = paddle::string::split_string(param, "|"); - std::set cmd_set(cmd.begin(), cmd.end()); - bool load_edge = cmd_set.count(std::string("edge")); + std::set cmd_set(cmd.begin(), cmd.end()); bool reverse_edge = cmd_set.count(std::string("reverse")); - VLOG(0) << "Reverse Edge " << reverse_edge; + bool load_edge = cmd_set.count(std::string("edge")); + if(load_edge) { + return this -> load_edges(path, reverse_edge); + } + else { + return this -> load_nodes(path); + } +} + +int32_t GraphTable::load_nodes(const std::string &path) { + auto paths = paddle::string::split_string(path, ";"); + for (auto path : paths) { + std::ifstream file(path); + std::string line; + while (std::getline(file, line)) { + auto values = paddle::string::split_string(line, "\t"); + if (values.size() < 2) continue; + auto id = std::stoull(values[1]); + + + size_t shard_id = id % shard_num; + if (shard_id >= shard_end || shard_id < shard_start) { + VLOG(0) << "will not load " << id << " from " << path + << ", please check id distribution"; + continue; + + } + + std::string node_type = values[0]; + std::vector feature; + feature.push_back(node_type); + for(size_t slice = 2; slice < values.size(); slice ++) { + feature.push_back(values[slice]); + } + auto feat = paddle::string::join_strings(feature, '\t'); + size_t index = shard_id - shard_start; + shards[index].add_node(id, feat); + + } + } + return 0; +} + + +int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) { auto paths = paddle::string::split_string(path, ";"); - VLOG(0) << paths.size(); int count = 0; for (auto path : paths) { @@ -113,6 +156,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { if (values.size() == 3) { weight = std::stof(values[2]); } + size_t src_shard_id = src_id % shard_num; if (src_shard_id >= shard_end || src_shard_id < shard_start) { @@ -120,6 +164,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { << ", please check id distribution"; continue; } + size_t index = src_shard_id - shard_start; GraphEdge *edge = new GraphEdge(dst_id, weight); shards[index].add_neighboor(src_id, edge); @@ -128,6 +173,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { VLOG(0) << "Load Finished Total Edge Count " << count; // Build Sampler j + for (auto &shard : shards) { auto bucket = shard.get_bucket(); for (int i = 0; i < bucket.size(); i++) { @@ -141,6 +187,7 @@ int32_t GraphTable::load(const std::string &path, const std::string ¶m) { } return 0; } + GraphNode *GraphTable::find_node(uint64_t id) { size_t shard_id = id % shard_num; if (shard_id >= shard_end || shard_id < shard_start) { @@ -264,3 +311,4 @@ int32_t GraphTable::initialize() { } } }; + diff --git a/paddle/fluid/distributed/table/common_graph_table.h b/paddle/fluid/distributed/table/common_graph_table.h index 1f2b8c86d363b..decf5f1f20462 100644 --- a/paddle/fluid/distributed/table/common_graph_table.h +++ b/paddle/fluid/distributed/table/common_graph_table.h @@ -52,7 +52,7 @@ class GraphShard { } return -1; } - std::list::iterator add_node(GraphNode *node); + std::list::iterator add_node(uint64_t id, std::string feature); GraphNode *find_node(uint64_t id); void add_neighboor(uint64_t id, GraphEdge *edge); std::unordered_map::iterator> @@ -74,7 +74,13 @@ class GraphTable : public SparseTable { virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer, int &actual_size); virtual int32_t initialize(); + int32_t load(const std::string &path, const std::string ¶m); + + int32_t load_edges(const std::string &path, bool reverse); + + int32_t load_nodes(const std::string &path); + GraphNode *find_node(uint64_t id); virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {