diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc
index 995eafcdd61eb..0472d83420702 100644
--- a/paddle/fluid/distributed/ps/table/common_graph_table.cc
+++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -74,6 +74,11 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
   std::vector<uint64_t> node_id_array[task_pool_size_];
   std::vector<paddle::framework::GpuPsFeaInfo> node_fea_info_array[task_pool_size_];
+  slot_feature_num_map_.resize(slot_num);
+  for (int k = 0; k < slot_num; ++k) {
+    slot_feature_num_map_[k] = 0;
+  }
+
   for (size_t i = 0; i < bags.size(); i++) {
     if (bags[i].size() > 0) {
       tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
@@ -94,13 +99,17 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
         int total_feature_size = 0;
         for (int k = 0; k < slot_num; ++k) {
           v->get_feature_ids(k, &feature_ids);
-          total_feature_size += feature_ids.size();
+          int feature_ids_size = feature_ids.size();
+          if (slot_feature_num_map_[k] < feature_ids_size) {
+            slot_feature_num_map_[k] = feature_ids_size;
+          }
+          total_feature_size += feature_ids_size;
           if (!feature_ids.empty()) {
             feature_array[i].insert(feature_array[i].end(),
                                     feature_ids.begin(),
                                     feature_ids.end());
             slot_id_array[i].insert(
-                slot_id_array[i].end(), feature_ids.size(), k);
+                slot_id_array[i].end(), feature_ids_size, k);
           }
         }
         x.feature_size = total_feature_size;
@@ -113,6 +122,13 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
     }
   }
   for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
+
+  std::stringstream ss;
+  for (int k = 0; k < slot_num; ++k) {
+    ss << slot_feature_num_map_[k] << " ";
+  }
+  VLOG(0) << "slot_feature_num_map: " << ss.str();
+
   paddle::framework::GpuPsCommGraphFea res;
   uint64_t tot_len = 0;
   for (int i = 0; i < task_pool_size_; i++) {
@@ -1210,23 +1226,27 @@ int32_t GraphTable::parse_node_and_load(std::string ntype2files,
     VLOG(0) << "parse node type and nodedir failed!";
     return -1;
   }
-
-  std::string delim = ";";
-  std::string npath = node_to_nodedir[ntypes[0]];
-  auto npath_list = paddle::framework::localfs_list(npath);
-  std::string npath_str;
-  if (part_num > 0 && part_num < (int)npath_list.size()) {
-    std::vector<std::string> sub_npath_list(
-        npath_list.begin(), npath_list.begin() + part_num);
-    npath_str = paddle::string::join_strings(sub_npath_list, delim);
-  } else {
-    npath_str = paddle::string::join_strings(npath_list, delim);
-  }
   if (ntypes.size() == 0) {
     VLOG(0) << "node_type not specified, nothing will be loaded ";
     return 0;
   }
+  std::string delim = ";";
+  std::vector<std::string> type_npath_strs;
+  for (size_t i = 0; i < ntypes.size(); i++) {
+    std::string npath = node_to_nodedir[ntypes[i]];
+    auto npath_list = paddle::framework::localfs_list(npath);
+    std::string type_npath_str;
+    if (part_num > 0 && part_num < (int)npath_list.size()) {
+      std::vector<std::string> sub_npath_list(
+          npath_list.begin(), npath_list.begin() + part_num);
+      type_npath_str = paddle::string::join_strings(sub_npath_list, delim);
+    } else {
+      type_npath_str = paddle::string::join_strings(npath_list, delim);
+    }
+    type_npath_strs.push_back(type_npath_str);
+  }
+  std::string npath_str = paddle::string::join_strings(type_npath_strs, delim);
   if (FLAGS_graph_load_in_parallel) {
     this->load_nodes(npath_str, "");
   } else {
@@ -1303,7 +1323,6 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype2files,
     VLOG(0) << "node_type not specified, nothing will be loaded ";
     return 0;
   }
-
   if (FLAGS_graph_load_in_parallel) {
     this->load_nodes(npath_str, "");
   } else {
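Note (commentary, not part of the patch): `make_gpu_ps_graph_fea` now records, for every slot, the largest number of feature ids any node carries, and the data generator later sums those maxima into `fea_num_per_node_`. A minimal host-side sketch of that max-reduction, with invented node data:

```cpp
// Standalone sketch of how slot_feature_num_map is derived: each entry is the
// maximum feature count observed for that slot across all nodes, so the GPU
// layout can reserve a fixed number of uint64 ids per slot. Data is invented.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical feature ids, indexed [node][slot].
  std::vector<std::vector<std::vector<uint64_t>>> node_slot_feas = {
      {{101}, {201, 202}, {}},     // node 0: 1, 2, 0 ids in slots 0..2
      {{111, 112}, {211}, {301}},  // node 1: 2, 1, 1 ids
  };
  int slot_num = 3;
  std::vector<int> slot_feature_num_map(slot_num, 0);
  for (auto& slots : node_slot_feas) {
    for (int k = 0; k < slot_num; ++k) {
      int n = static_cast<int>(slots[k].size());
      if (slot_feature_num_map[k] < n) slot_feature_num_map[k] = n;  // max
    }
  }
  int fea_num_per_node = 0;  // padded width of one node's feature row
  for (int k = 0; k < slot_num; ++k) fea_num_per_node += slot_feature_num_map[k];
  std::cout << "fea_num_per_node = " << fea_num_per_node << "\n";  // 2+2+1 = 5
  return 0;
}
```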
@@ -1985,9 +2004,14 @@ int GraphTable::parse_feature(int idx,
   // Return (feat_id, btyes) if name are in this->feat_name, else return (-1,
   // "")
   thread_local std::vector<paddle::string::str_ptr> fields;
   fields.clear();
-  const char c = feature_separator_.at(0);
+  char c = slot_feature_separator_.at(0);
   paddle::string::split_string_ptr(feat_str, len, c, &fields);
+  thread_local std::vector<paddle::string::str_ptr> fea_fields;
+  fea_fields.clear();
+  c = feature_separator_.at(0);
+  paddle::string::split_string_ptr(fields[1].ptr, fields[1].len, c, &fea_fields);
+
   std::string name = fields[0].to_string();
   auto it = feat_id_map[idx].find(name);
   if (it != feat_id_map[idx].end()) {
@@ -1998,26 +2022,26 @@ int GraphTable::parse_feature(int idx,
       //      string_vector_2_string(fields.begin() + 1, fields.end(), ' ',
       //      fea_ptr);
       FeatureNode::parse_value_to_bytes<uint64_t>(
-          fields.begin() + 1, fields.end(), fea_ptr);
+          fea_fields.begin(), fea_fields.end(), fea_ptr);
       return 0;
     } else if (dtype == "string") {
-      string_vector_2_string(fields.begin() + 1, fields.end(), ' ', fea_ptr);
+      string_vector_2_string(fea_fields.begin(), fea_fields.end(), ' ', fea_ptr);
       return 0;
     } else if (dtype == "float32") {
       FeatureNode::parse_value_to_bytes<float>(
-          fields.begin() + 1, fields.end(), fea_ptr);
+          fea_fields.begin(), fea_fields.end(), fea_ptr);
       return 0;
     } else if (dtype == "float64") {
       FeatureNode::parse_value_to_bytes<double>(
-          fields.begin() + 1, fields.end(), fea_ptr);
+          fea_fields.begin(), fea_fields.end(), fea_ptr);
       return 0;
     } else if (dtype == "int32") {
       FeatureNode::parse_value_to_bytes<int32_t>(
-          fields.begin() + 1, fields.end(), fea_ptr);
+          fea_fields.begin(), fea_fields.end(), fea_ptr);
       return 0;
     } else if (dtype == "int64") {
       FeatureNode::parse_value_to_bytes<int64_t>(
-          fields.begin() + 1, fields.end(), fea_ptr);
+          fea_fields.begin(), fea_fields.end(), fea_ptr);
       return 0;
     }
   } else {
@@ -2254,6 +2278,10 @@ void GraphTable::set_feature_separator(const std::string &ch) {
   feature_separator_ = ch;
 }
 
+void GraphTable::set_slot_feature_separator(const std::string &ch) {
+  slot_feature_separator_ = ch;
+}
+
 int32_t GraphTable::get_server_index_by_id(uint64_t id) {
   return id % shard_num / shard_num_per_server;
 }
diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h
index 9b416b30d9788..d122df15a6582 100644
--- a/paddle/fluid/distributed/ps/table/common_graph_table.h
+++ b/paddle/fluid/distributed/ps/table/common_graph_table.h
@@ -688,6 +688,7 @@ class GraphTable : public Table {
   int32_t make_complementary_graph(int idx, int64_t byte_size);
   int32_t dump_edges_to_ssd(int idx);
   int32_t get_partition_num(int idx) { return partitions[idx].size(); }
+  std::vector<int> slot_feature_num_map() const { return slot_feature_num_map_; }
   std::vector<uint64_t> get_partition(int idx, int index) {
     if (idx >= (int)partitions.size() || index >= (int)partitions[idx].size())
       return std::vector<uint64_t>();
@@ -705,6 +706,7 @@ class GraphTable : public Table {
 #endif
   virtual int32_t add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id);
   virtual int32_t build_sampler(int idx, std::string sample_type = "random");
+  void set_slot_feature_separator(const std::string &ch);
   void set_feature_separator(const std::string &ch);
 
   void build_graph_total_keys();
@@ -751,7 +753,9 @@ class GraphTable : public Table {
   // std::shared_ptr<GraphSampler> graph_sampler;
   // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
 #endif
+  std::string slot_feature_separator_ = std::string(" ");
   std::string feature_separator_ = std::string(" ");
+  std::vector<int> slot_feature_num_map_;
 };
 
 /*
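Note (commentary, not part of the patch): `parse_feature` now splits each feature string twice — first on `slot_feature_separator_` to isolate the slot name from its packed value list, then on `feature_separator_` to split the values themselves. A self-contained sketch of that two-level format; the separator characters below are illustrative (the patch defaults both to a single space):

```cpp
// Sketch of the two-level split performed by GraphTable::parse_feature.
// A feature line looks like:
//   <slot_name><slot_feature_separator><v1><feature_separator><v2>...
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

static std::vector<std::string> split(const std::string& s, char sep) {
  std::vector<std::string> out;
  std::string tok;
  std::istringstream in(s);
  while (std::getline(in, tok, sep)) out.push_back(tok);
  return out;
}

int main() {
  const char slot_sep = ':';  // stands in for slot_feature_separator_
  const char fea_sep = ',';   // stands in for feature_separator_
  std::string feat_str = "user_click:1001,1002,1003";

  auto fields = split(feat_str, slot_sep);      // fields[0] = slot name
  auto fea_fields = split(fields[1], fea_sep);  // fields[1] = packed values

  std::cout << "slot = " << fields[0] << ", values =";
  for (auto& v : fea_fields) std::cout << ' ' << v;
  std::cout << '\n';  // slot = user_click, values = 1001 1002 1003
  return 0;
}
```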
diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 074fa407c5b51..38513c96cf964 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -325,22 +325,29 @@
 __global__ void GraphFillSlotKernel(uint64_t *id_tensor,
                                     uint64_t *feature_buf,
                                     int len,
                                     int total_ins,
-                                    int slot_num) {
+                                    int slot_num,
+                                    int* slot_feature_num_map,
+                                    int fea_num_per_node,
+                                    int* actual_slot_id_map,
+                                    int* fea_offset_map) {
   CUDA_KERNEL_LOOP(idx, len) {
-    int slot_idx = idx / total_ins;
+    int fea_idx = idx / total_ins;
     int ins_idx = idx % total_ins;
-    ((uint64_t *)(id_tensor[slot_idx]))[ins_idx] =
-        feature_buf[ins_idx * slot_num + slot_idx];
+    int actual_slot_id = actual_slot_id_map[fea_idx];
+    int fea_offset = fea_offset_map[fea_idx];
+    ((uint64_t *)(id_tensor[actual_slot_id]))[ins_idx * slot_feature_num_map[actual_slot_id] + fea_offset]
+        = feature_buf[ins_idx * fea_num_per_node + fea_idx];
   }
 }
 
 __global__ void GraphFillSlotLodKernelOpt(uint64_t *id_tensor,
                                           int len,
-                                          int total_ins) {
+                                          int total_ins,
+                                          int* slot_feature_num_map) {
   CUDA_KERNEL_LOOP(idx, len) {
     int slot_idx = idx / total_ins;
     int ins_idx = idx % total_ins;
-    ((uint64_t *)(id_tensor[slot_idx]))[ins_idx] = ins_idx;
+    ((uint64_t *)(id_tensor[slot_idx]))[ins_idx] = ins_idx * slot_feature_num_map[slot_idx];
   }
 }
@@ -396,7 +403,7 @@ int GraphDataGenerator::FillGraphSlotFeature(int total_instance,
   int64_t *slot_lod_tensor_ptr_[slot_num_];
   for (int i = 0; i < slot_num_; ++i) {
     slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data<uint64_t>(
-        {total_instance, 1}, this->place_);
+        {total_instance * h_slot_feature_num_map_[i], 1}, this->place_);
     slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
         {total_instance + 1}, this->place_);
   }
@@ -422,43 +429,48 @@ int GraphDataGenerator::FillGraphSlotFeature(int total_instance,
                   train_stream_);
   uint64_t *feature_buf = reinterpret_cast<uint64_t *>(d_feature_buf_->ptr());
   FillFeatureBuf(ins_cursor, feature_buf, total_instance);
-  GraphFillSlotKernel<<<GET_BLOCKS(total_instance * slot_num_),
+  GraphFillSlotKernel<<<GET_BLOCKS(total_instance * fea_num_per_node_),
                        CUDA_NUM_THREADS,
                        0,
                        train_stream_>>>((uint64_t *)d_slot_tensor_ptr_->ptr(),
                                         feature_buf,
-                                        total_instance * slot_num_,
+                                        total_instance * fea_num_per_node_,
                                         total_instance,
-                                        slot_num_);
+                                        slot_num_,
+                                        (int*)d_slot_feature_num_map_->ptr(),
+                                        fea_num_per_node_,
+                                        (int*)d_actual_slot_id_map_->ptr(),
+                                        (int*)d_fea_offset_map_->ptr());
   GraphFillSlotLodKernelOpt<<<GET_BLOCKS((total_instance + 1) * slot_num_),
                               CUDA_NUM_THREADS,
                               0,
                               train_stream_>>>(
       (uint64_t *)d_slot_lod_tensor_ptr_->ptr(),
       (total_instance + 1) * slot_num_,
-      total_instance + 1);
+      total_instance + 1,
+      (int*)d_slot_feature_num_map_->ptr());
 
   if (debug_mode_) {
     uint64_t h_walk[total_instance];
     cudaMemcpy(h_walk,
                ins_cursor,
                total_instance * sizeof(uint64_t),
                cudaMemcpyDeviceToHost);
-    uint64_t h_feature[total_instance * slot_num_];
+    uint64_t h_feature[total_instance * slot_num_ * fea_num_per_node_];
     cudaMemcpy(h_feature,
                feature_buf,
-               total_instance * slot_num_ * sizeof(uint64_t),
+               total_instance * fea_num_per_node_ * slot_num_ * sizeof(uint64_t),
               cudaMemcpyDeviceToHost);
     for (int i = 0; i < total_instance; ++i) {
       std::stringstream ss;
-      for (int j = 0; j < slot_num_; ++j) {
-        ss << h_feature[i * slot_num_ + j] << " ";
+      for (int j = 0; j < fea_num_per_node_; ++j) {
+        ss << h_feature[i * fea_num_per_node_ + j] << " ";
       }
       VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i
-              << "] = " << (uint64_t)h_walk[i] << " feature[" << i * slot_num_
-              << ".." << (i + 1) * slot_num_ << "] = " << ss.str();
+              << "] = " << (uint64_t)h_walk[i] << " feature[" << i * fea_num_per_node_
+              << ".." << (i + 1) * fea_num_per_node_ << "] = " << ss.str();
     }
-    uint64_t h_slot_tensor[slot_num_][total_instance];
+    uint64_t h_slot_tensor[fea_num_per_node_][total_instance];
     uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1];
     for (int i = 0; i < slot_num_; ++i) {
       cudaMemcpy(h_slot_tensor[i],
@@ -593,7 +605,6 @@ int GraphDataGenerator::GenerateBatch() {
   cudaStreamSynchronize(train_stream_);
   if (!gpu_graph_training_) return 1;
   ins_buf_pair_len_ -= total_instance / 2;
-
   return 1;
 }
 
@@ -833,7 +844,8 @@ int GraphDataGenerator::FillFeatureBuf(uint64_t *d_walk,
 
   auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
   int ret = gpu_graph_ptr->get_feature_of_nodes(
-      gpuid_, d_walk, d_feature, key_num, slot_num_);
+      gpuid_, d_walk, d_feature, key_num, slot_num_,
+      (int*)d_slot_feature_num_map_->ptr(), fea_num_per_node_);
   return ret;
 }
 
@@ -847,7 +859,9 @@ int GraphDataGenerator::FillFeatureBuf(
       (uint64_t *)d_walk->ptr(),
       (uint64_t *)d_feature->ptr(),
       buf_size_,
-      slot_num_);
+      slot_num_,
+      (int*)d_slot_feature_num_map_->ptr(),
+      fea_num_per_node_);
   return ret;
 }
 
@@ -1283,7 +1297,7 @@ int GraphDataGenerator::FillWalkBuf() {
         size_t batch = 0;
         d_feature_list_ = memory::AllocShared(
             place_,
-            once_sample_startid_len_ * slot_num_ * sizeof(uint64_t),
+            once_sample_startid_len_ * fea_num_per_node_ * sizeof(uint64_t),
             phi::Stream(reinterpret_cast<phi::StreamId>(sample_stream_)));
         uint64_t *d_feature_list_ptr =
             reinterpret_cast<uint64_t *>(d_feature_list_->ptr());
@@ -1296,9 +1310,11 @@ int GraphDataGenerator::FillWalkBuf() {
               d_uniq_node_ptr + cursor,
               d_feature_list_ptr,
               batch,
-              slot_num_);
+              slot_num_,
+              (int*)d_slot_feature_num_map_->ptr(),
+              fea_num_per_node_);
           if (InsertTable(
-                  d_feature_list_ptr, slot_num_ * batch, d_uniq_fea_num)) {
+                  d_feature_list_ptr, fea_num_per_node_ * batch, d_uniq_fea_num)) {
             CopyFeaFromTable(d_uniq_fea_num);
             table_->clear(sample_stream_);
             cudaMemsetAsync(
@@ -1368,6 +1384,32 @@ void GraphDataGenerator::AllocResource(int thread_id,
     h_device_keys_len_.push_back(h_graph_all_type_keys_len[i][thread_id]);
   }
   VLOG(2) << "h_device_keys size: " << h_device_keys_len_.size();
+
+  h_slot_feature_num_map_ = gpu_graph_ptr->slot_feature_num_map();
+  fea_num_per_node_ = 0;
+  for (int i = 0; i < slot_num_; ++i) {
+    fea_num_per_node_ += h_slot_feature_num_map_[i];
+  }
+  std::vector<int> h_actual_slot_id_map, h_fea_offset_map;
+  h_actual_slot_id_map.resize(fea_num_per_node_);
+  h_fea_offset_map.resize(fea_num_per_node_);
+  for (int slot_id = 0, fea_idx = 0; slot_id < slot_num_; ++slot_id) {
+    for (int j = 0; j < h_slot_feature_num_map_[slot_id]; ++j, ++fea_idx) {
+      h_actual_slot_id_map[fea_idx] = slot_id;
+      h_fea_offset_map[fea_idx] = j;
+    }
+  }
+
+  d_slot_feature_num_map_ = memory::Alloc(place_, slot_num_ * sizeof(int));
+  cudaMemcpy(d_slot_feature_num_map_->ptr(), h_slot_feature_num_map_.data(),
+             sizeof(int) * slot_num_, cudaMemcpyHostToDevice);
+  d_actual_slot_id_map_ = memory::Alloc(place_, fea_num_per_node_ * sizeof(int));
+  cudaMemcpy(d_actual_slot_id_map_->ptr(), h_actual_slot_id_map.data(),
+             sizeof(int) * fea_num_per_node_, cudaMemcpyHostToDevice);
+  d_fea_offset_map_ = memory::Alloc(place_, fea_num_per_node_ * sizeof(int));
+  cudaMemcpy(d_fea_offset_map_->ptr(), h_fea_offset_map.data(),
+             sizeof(int) * fea_num_per_node_, cudaMemcpyHostToDevice);
+
   size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
   d_prefix_sum_ = memory::AllocShared(
       place_,
@@ -1429,7 +1471,7 @@ void GraphDataGenerator::AllocResource(int thread_id,
       memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(uint64_t));
   if (slot_num_ > 0) {
     d_feature_buf_ = memory::AllocShared(
-        place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(uint64_t));
+        place_, (batch_size_ * 2 * 2) * fea_num_per_node_ * sizeof(uint64_t));
   }
   d_pair_num_ = memory::AllocShared(place_, sizeof(int));
 
diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h
index 7ae03968598aa..d4b701a79ea96 100644
--- a/paddle/fluid/framework/data_feed.h
+++ b/paddle/fluid/framework/data_feed.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -971,6 +971,9 @@ class GraphDataGenerator {
   std::shared_ptr<phi::Allocation> d_len_per_row_;
   std::shared_ptr<phi::Allocation> d_random_row_;
   std::shared_ptr<phi::Allocation> d_uniq_node_num_;
+  std::shared_ptr<phi::Allocation> d_slot_feature_num_map_;
+  std::shared_ptr<phi::Allocation> d_actual_slot_id_map_;
+  std::shared_ptr<phi::Allocation> d_fea_offset_map_;
   // std::vector<std::shared_ptr<phi::Allocation>> d_sampleidx2rows_;
   int cur_sampleidx2row_;
 
@@ -991,6 +994,8 @@ class GraphDataGenerator {
   BufState buf_state_;
   int batch_size_;
   int slot_num_;
+  std::vector<int> h_slot_feature_num_map_;
+  int fea_num_per_node_;
   int shuffle_seed_;
   int debug_mode_;
   bool gpu_graph_training_;
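Note (commentary, not part of the patch): the `AllocResource` hunk above flattens `slot_feature_num_map` into two lookup tables that `GraphFillSlotKernel` indexes by flat feature position: `actual_slot_id_map[f]` gives the owning slot and `fea_offset_map[f]` the position inside that slot. A host-side sketch with invented capacities:

```cpp
// Sketch mirroring the AllocResource hunk: flattening per-slot capacities
// into (slot id, offset) lookup tables for a flat feature index.
#include <iostream>
#include <vector>

int main() {
  std::vector<int> slot_feature_num_map = {2, 1, 3};  // per-slot capacities
  int fea_num_per_node = 0;
  for (int n : slot_feature_num_map) fea_num_per_node += n;  // 6

  std::vector<int> actual_slot_id_map(fea_num_per_node);
  std::vector<int> fea_offset_map(fea_num_per_node);
  for (int slot_id = 0, fea_idx = 0;
       slot_id < (int)slot_feature_num_map.size(); ++slot_id) {
    for (int j = 0; j < slot_feature_num_map[slot_id]; ++j, ++fea_idx) {
      actual_slot_id_map[fea_idx] = slot_id;  // which slot owns position fea_idx
      fea_offset_map[fea_idx] = j;            // offset inside that slot
    }
  }
  for (int f = 0; f < fea_num_per_node; ++f) {
    std::cout << f << " -> slot " << actual_slot_id_map[f] << ", offset "
              << fea_offset_map[f] << "\n";
  }
  // 0->(0,0) 1->(0,1) 2->(1,0) 3->(2,0) 4->(2,1) 5->(2,2)
  return 0;
}
```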
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
index d890b74ff9b2e..3799fa55bb503 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -148,7 +148,8 @@ class GpuPsGraphTable
                               bool cpu_query_switch);
 
   int get_feature_of_nodes(
-      int gpu_id, uint64_t *d_walk, uint64_t *d_offset, int size, int slot_num);
+      int gpu_id, uint64_t *d_walk, uint64_t *d_offset, int size, int slot_num,
+      int* d_slot_feature_num_map, int fea_num_per_node);
 
   NodeQueryResult query_node_list(int gpu_id,
                                   int idx,
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
index eac888383a3de..384cc1bfb3eaa 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -82,36 +82,45 @@
 __global__ void get_features_kernel(GpuPsCommGraphFea graph,
                                     GpuPsFeaInfo* fea_info_array,
                                     int* actual_size,
                                     uint64_t* feature,
+                                    int* slot_feature_num_map,
                                     int slot_num,
-                                    int n) {
+                                    int n,
+                                    int fea_num_per_node) {
   int idx = blockIdx.x * blockDim.y + threadIdx.y;
   if (idx < n) {
     int feature_size = fea_info_array[idx].feature_size;
-    int offset = idx * slot_num;
+    int src_offset = fea_info_array[idx].feature_offset;
+    int dst_offset = idx * fea_num_per_node;
+    uint64_t* dst_feature = &feature[dst_offset];
     if (feature_size == 0) {
-      for (int k = 0; k < slot_num; ++k) {
-        feature[offset + k] = 0;
+      for (int k = 0; k < fea_num_per_node; ++k) {
+        dst_feature[k] = 0;
       }
-      actual_size[idx] = slot_num;
+      actual_size[idx] = fea_num_per_node;
       return;
     }
 
-    uint64_t* feature_start =
-        &(graph.feature_list[fea_info_array[idx].feature_offset]);
-    uint8_t* slot_id_start =
-        &(graph.slot_id_list[fea_info_array[idx].feature_offset]);
-    int m = 0;
-    for (int k = 0; k < slot_num; ++k) {
-      if (m >= fea_info_array[idx].feature_size || k < slot_id_start[m]) {
-        feature[offset + k] = 0;
-      } else if (k == slot_id_start[m]) {
-        feature[offset + k] = feature_start[m];
-        ++m;
+    uint64_t* feature_start = &(graph.feature_list[src_offset]);
+    uint8_t* slot_id_start = &(graph.slot_id_list[src_offset]);
+    for (int slot_id = 0, dst_fea_idx = 0, src_fea_idx = 0; slot_id < slot_num; slot_id++) {
+      int feature_num = slot_feature_num_map[slot_id];
+      if (src_fea_idx >= feature_size || slot_id < slot_id_start[src_fea_idx]) {
+        for (int j = 0; j < feature_num; ++j, ++dst_fea_idx) {
+          dst_feature[dst_fea_idx] = 0;
+        }
+      } else if (slot_id == slot_id_start[src_fea_idx]) {
+        for (int j = 0; j < feature_num; ++j, ++dst_fea_idx) {
+          if (slot_id == slot_id_start[src_fea_idx]) {
+            dst_feature[dst_fea_idx] = feature_start[src_fea_idx++];
+          } else {
+            dst_feature[dst_fea_idx] = 0;
+          }
+        }
       } else {
         assert(0);
       }
     }
-    actual_size[idx] = slot_num;
+    actual_size[idx] = fea_num_per_node;
   }
 }
@@ -1049,7 +1058,9 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
                                           uint64_t* d_nodes,
                                           uint64_t* d_feature,
                                           int node_num,
-                                          int slot_num) {
+                                          int slot_num,
+                                          int* d_slot_feature_num_map,
+                                          int fea_num_per_node) {
   device_mutex_[gpu_id]->lock();
   if (node_num == 0) {
     return -1;
@@ -1087,7 +1098,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
   uint64_t* d_shard_keys_ptr = reinterpret_cast<uint64_t*>(d_shard_keys->ptr());
   auto d_shard_vals =
       memory::Alloc(place,
-                    slot_num * node_num * sizeof(uint64_t),
+                    fea_num_per_node * node_num * sizeof(uint64_t),
                     phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
   uint64_t* d_shard_vals_ptr = reinterpret_cast<uint64_t*>(d_shard_vals->ptr());
   auto d_shard_actual_size =
@@ -1118,7 +1129,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
     create_storage(gpu_id,
                    i,
                    shard_len * sizeof(uint64_t),
-                   shard_len * slot_num * sizeof(uint64_t) +
+                   shard_len * fea_num_per_node * sizeof(uint64_t) +
                        shard_len * sizeof(uint64_t) +
                        sizeof(int) * (shard_len + shard_len % 2));
   }
@@ -1161,8 +1172,10 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
         val_array,
         actual_size_array,
         feature_array,
+        d_slot_feature_num_map,
         slot_num,
-        shard_len);
+        shard_len,
+        fea_num_per_node);
   }
 
   for (int i = 0; i < total_gpu; ++i) {
@@ -1174,7 +1187,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
 
   move_result_to_source_gpu(gpu_id,
                             total_gpu,
-                            slot_num,
+                            fea_num_per_node,
                             h_left,
                             h_right,
                             d_shard_vals_ptr,
@@ -1185,7 +1198,7 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
       d_feature,
       d_shard_actual_size_ptr,
       d_idx_ptr,
-      slot_num,
+      fea_num_per_node,
       node_num);
 
   for (int i = 0; i < total_gpu; ++i) {
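Note (commentary, not part of the patch): a host-side model of the rewritten `get_features_kernel` copy loop. The kernel's three branches (including the `assert(0)` guard for out-of-order slot ids) are collapsed into one test here, but the output layout is the same for well-formed input: each node's sparse (feature id, slot id) pairs are expanded into a fixed `fea_num_per_node`-wide row, zero-padded per slot. The data below is invented:

```cpp
// Host-side model of the padded expansion in get_features_kernel.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> slot_feature_num_map = {2, 1, 3};     // per-slot capacities
  std::vector<uint64_t> feature_list = {101, 301, 302};  // one node's ids
  std::vector<uint8_t> slot_id_list = {0, 2, 2};         // owning slot of each id
  int feature_size = (int)feature_list.size();
  int slot_num = (int)slot_feature_num_map.size();

  std::vector<uint64_t> dst_feature;  // padded output row
  int src = 0;
  for (int slot_id = 0; slot_id < slot_num; ++slot_id) {
    for (int j = 0; j < slot_feature_num_map[slot_id]; ++j) {
      if (src < feature_size && slot_id_list[src] == slot_id) {
        dst_feature.push_back(feature_list[src++]);  // copy next id of this slot
      } else {
        dst_feature.push_back(0);  // pad unused positions with 0
      }
    }
  }
  assert(src == feature_size);  // every source id was placed
  for (uint64_t v : dst_feature) std::cout << v << ' ';
  std::cout << '\n';  // 101 0 0 301 302 0
  return 0;
}
```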
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
index 6e0ea83d92991..7a766413fb049 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -214,6 +214,14 @@ void GraphGpuWrapper::set_feature_separator(std::string ch) {
   }
 }
 
+void GraphGpuWrapper::set_slot_feature_separator(std::string ch) {
+  slot_feature_separator_ = ch;
+  if (graph_table != nullptr) {
+    ((GpuPsGraphTable *)graph_table)
+        ->cpu_graph_table_->set_slot_feature_separator(slot_feature_separator_);
+  }
+}
+
 void GraphGpuWrapper::make_partitions(int idx,
                                       int64_t byte_size,
                                       int device_len) {
@@ -356,6 +364,7 @@ void GraphGpuWrapper::init_service() {
   GpuPsGraphTable *g = new GpuPsGraphTable(resource, 1, id_to_edge.size());
   g->init_cpu_table(table_proto);
   g->cpu_graph_table_->set_feature_separator(feature_separator_);
+  g->cpu_graph_table_->set_slot_feature_separator(slot_feature_separator_);
   graph_table = (char *)g;
   upload_task_pool.reset(new ::ThreadPool(upload_num));
 }
@@ -430,13 +439,16 @@ int GraphGpuWrapper::get_feature_of_nodes(int gpu_id,
                                           uint64_t *d_walk,
                                           uint64_t *d_offset,
                                           uint32_t size,
-                                          int slot_num) {
+                                          int slot_num,
+                                          int* d_slot_feature_num_map,
+                                          int fea_num_per_node) {
   platform::CUDADeviceGuard guard(gpu_id);
   PADDLE_ENFORCE_NOT_NULL(graph_table,
                           paddle::platform::errors::InvalidArgument(
                               "graph_table should not be null"));
   return ((GpuPsGraphTable *)graph_table)
-      ->get_feature_of_nodes(gpu_id, d_walk, d_offset, size, slot_num);
+      ->get_feature_of_nodes(gpu_id, d_walk, d_offset, size, slot_num,
+                             d_slot_feature_num_map, fea_num_per_node);
 }
 
 NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample(
@@ -514,6 +526,11 @@ void GraphGpuWrapper::load_node_weight(int type_id, int idx, std::string path) {
       ->cpu_graph_table_->load_node_weight(type_id, idx, path);
 }
 
+std::vector<int> GraphGpuWrapper::slot_feature_num_map() const {
+  return ((GpuPsGraphTable *)graph_table)
+      ->cpu_graph_table_->slot_feature_num_map();
+}
+
 void GraphGpuWrapper::export_partition_files(int idx, std::string file_path) {
   return ((GpuPsGraphTable *)graph_table)
       ->cpu_graph_table_->export_partition_files(idx, file_path);
diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
index 8f7c29f487f4e..5a99498d3e160 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
@@ -116,12 +116,18 @@ class GraphGpuWrapper {
       int idx,
       std::vector<uint64_t>& key,
       int sample_size);
+  std::vector<std::shared_ptr<phi::Allocation>> get_edge_type_graph(
+      int gpu_id, int edge_type_len);
+  std::vector<int> slot_feature_num_map() const;
   void set_feature_separator(std::string ch);
+  void set_slot_feature_separator(std::string ch);
   int get_feature_of_nodes(int gpu_id,
                            uint64_t* d_walk,
                            uint64_t* d_offset,
                            uint32_t size,
-                           int slot_num);
+                           int slot_num,
+                           int* d_slot_feature_num_map,
+                           int fea_num_per_node);
 
   void release_graph();
   void release_graph_edge();
@@ -157,6 +163,7 @@ class GraphGpuWrapper {
   std::vector<std::vector<std::shared_ptr<phi::Allocation>>>
       d_graph_all_type_total_keys_;
   std::vector<std::vector<uint64_t>> h_graph_all_type_keys_len_;
+  std::string slot_feature_separator_ = std::string(" ");
 };
 #endif
 }  // namespace framework
diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index db2ab38530a2c..1c1d5a5269f30 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -368,6 +368,7 @@ void BindGraphGpuWrapper(py::module* m) {
                &GraphGpuWrapper::graph_neighbor_sample))
       .def("set_device", &GraphGpuWrapper::set_device)
       .def("set_feature_separator", &GraphGpuWrapper::set_feature_separator)
+      .def("set_slot_feature_separator", &GraphGpuWrapper::set_slot_feature_separator)
       .def("init_service", &GraphGpuWrapper::init_service)
       .def("set_up_types", &GraphGpuWrapper::set_up_types)
       .def("query_node_list", &GraphGpuWrapper::query_node_list)
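Note (commentary, not part of the patch): a sketch of how a caller might exercise the new API, using only names introduced in this diff; the separator values are illustrative. Because `set_slot_feature_separator` forwards to `cpu_graph_table_` when `graph_table` is already non-null, and `init_service` re-applies the stored value to a freshly created table, the setter works whether it is called before or after `init_service`:

```cpp
// Usage sketch against the wrapper API from this diff (not runnable outside
// the Paddle tree; configure_graph_wrapper is a hypothetical helper).
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"

void configure_graph_wrapper() {
  auto wrapper = paddle::framework::GraphGpuWrapper::GetInstance();
  wrapper->set_slot_feature_separator(":");  // between slot name and values
  wrapper->set_feature_separator(",");       // between values of one slot
  // ... set_up_types / load data / init_service as usual ...
  // After the CPU table is built, per-slot capacities are queryable:
  // std::vector<int> caps = wrapper->slot_feature_num_map();
}
```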