Skip to content

Commit

Permalink
support slot_feature with different length (PaddlePaddle#124)
Browse files Browse the repository at this point in the history
Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
  • Loading branch information
huwei02 and root committed Nov 26, 2022
1 parent 373252d commit 520e680
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 65 deletions.
45 changes: 35 additions & 10 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
std::vector<uint64_t> node_id_array[task_pool_size_];
std::vector<paddle::framework::GpuPsFeaInfo>
node_fea_info_array[task_pool_size_];
slot_feature_num_map_.resize(slot_num);
for (int k = 0; k < slot_num; ++k) {
slot_feature_num_map_[k] = 0;
}

for (size_t i = 0; i < bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
Expand All @@ -91,13 +96,17 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
int total_feature_size = 0;
for (int k = 0; k < slot_num; ++k) {
v->get_feature_ids(k, &feature_ids);
total_feature_size += feature_ids.size();
int feature_ids_size = feature_ids.size();
if (slot_feature_num_map_[k] < feature_ids_size) {
slot_feature_num_map_[k] = feature_ids_size;
}
total_feature_size += feature_ids_size;
if (!feature_ids.empty()) {
feature_array[i].insert(feature_array[i].end(),
feature_ids.begin(),
feature_ids.end());
slot_id_array[i].insert(
slot_id_array[i].end(), feature_ids.size(), k);
slot_id_array[i].end(), feature_ids_size, k);
}
}
x.feature_size = total_feature_size;
Expand All @@ -109,7 +118,14 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
}));
}
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();

std::stringstream ss;
for (int k = 0; k < slot_num; ++k) {
ss << slot_feature_num_map_[k] << " ";
}
VLOG(0) << "slot_feature_num_map: " << ss.str();

paddle::framework::GpuPsCommGraphFea res;
uint64_t tot_len = 0;
for (int i = 0; i < task_pool_size_; i++) {
Expand Down Expand Up @@ -1852,9 +1868,14 @@ int GraphTable::parse_feature(int idx,
// "")
thread_local std::vector<paddle::string::str_ptr> fields;
fields.clear();
const char c = feature_separator_.at(0);
char c = slot_feature_separator_.at(0);
paddle::string::split_string_ptr(feat_str, len, c, &fields);

thread_local std::vector<paddle::string::str_ptr> fea_fields;
fea_fields.clear();
c = feature_separator_.at(0);
paddle::string::split_string_ptr(fields[1].ptr, fields[1].len, c, &fea_fields);

std::string name = fields[0].to_string();
auto it = feat_id_map[idx].find(name);
if (it != feat_id_map[idx].end()) {
Expand All @@ -1865,26 +1886,26 @@ int GraphTable::parse_feature(int idx,
// string_vector_2_string(fields.begin() + 1, fields.end(), ' ',
// fea_ptr);
FeatureNode::parse_value_to_bytes<uint64_t>(
fields.begin() + 1, fields.end(), fea_ptr);
fea_fields.begin(), fea_fields.end(), fea_ptr);
return 0;
} else if (dtype == "string") {
string_vector_2_string(fields.begin() + 1, fields.end(), ' ', fea_ptr);
string_vector_2_string(fea_fields.begin(), fea_fields.end(), ' ', fea_ptr);
return 0;
} else if (dtype == "float32") {
FeatureNode::parse_value_to_bytes<float>(
fields.begin() + 1, fields.end(), fea_ptr);
fea_fields.begin(), fea_fields.end(), fea_ptr);
return 0;
} else if (dtype == "float64") {
FeatureNode::parse_value_to_bytes<double>(
fields.begin() + 1, fields.end(), fea_ptr);
fea_fields.begin(), fea_fields.end(), fea_ptr);
return 0;
} else if (dtype == "int32") {
FeatureNode::parse_value_to_bytes<int32_t>(
fields.begin() + 1, fields.end(), fea_ptr);
fea_fields.begin(), fea_fields.end(), fea_ptr);
return 0;
} else if (dtype == "int64") {
FeatureNode::parse_value_to_bytes<uint64_t>(
fields.begin() + 1, fields.end(), fea_ptr);
fea_fields.begin(), fea_fields.end(), fea_ptr);
return 0;
}
} else {
Expand Down Expand Up @@ -2111,6 +2132,10 @@ void GraphTable::set_feature_separator(const std::string &ch) {
feature_separator_ = ch;
}

// Sets the separator used to split a slot name from its feature values
// when parsing slot-feature strings (see GraphTable::parse_feature).
void GraphTable::set_slot_feature_separator(const std::string &ch) {
  slot_feature_separator_.assign(ch);
}

int32_t GraphTable::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_num_per_server;
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ class GraphTable : public Table {
int32_t make_complementary_graph(int idx, int64_t byte_size);
int32_t dump_edges_to_ssd(int idx);
int32_t get_partition_num(int idx) { return partitions[idx].size(); }
std::vector<int> slot_feature_num_map() const { return slot_feature_num_map_; }
std::vector<uint64_t> get_partition(int idx, int index) {
if (idx >= (int)partitions.size() || index >= (int)partitions[idx].size())
return std::vector<uint64_t>();
Expand All @@ -695,6 +696,7 @@ class GraphTable : public Table {
#endif
virtual int32_t add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id);
virtual int32_t build_sampler(int idx, std::string sample_type = "random");
void set_slot_feature_separator(const std::string &ch);
void set_feature_separator(const std::string &ch);
std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
Expand Down Expand Up @@ -732,7 +734,9 @@ class GraphTable : public Table {
// std::shared_ptr<GraphSampler> graph_sampler;
// REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
std::string slot_feature_separator_ = std::string(" ");
std::string feature_separator_ = std::string(" ");
std::vector<int> slot_feature_num_map_;
};

/*
Expand Down
95 changes: 67 additions & 28 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -341,22 +341,29 @@ __global__ void GraphFillSlotKernel(uint64_t *id_tensor,
uint64_t *feature_buf,
int len,
int total_ins,
int slot_num) {
int slot_num,
int* slot_feature_num_map,
int fea_num_per_node,
int* actual_slot_id_map,
int* fea_offset_map) {
CUDA_KERNEL_LOOP(idx, len) {
int slot_idx = idx / total_ins;
int fea_idx = idx / total_ins;
int ins_idx = idx % total_ins;
((uint64_t *)(id_tensor[slot_idx]))[ins_idx] = // NOLINT
feature_buf[ins_idx * slot_num + slot_idx];
int actual_slot_id = actual_slot_id_map[fea_idx];
int fea_offset = fea_offset_map[fea_idx];
((uint64_t *)(id_tensor[actual_slot_id]))[ins_idx * slot_feature_num_map[actual_slot_id] + fea_offset]
= feature_buf[ins_idx * fea_num_per_node + fea_idx];
}
}

// Fills each slot's LoD (length-offset) tensor: entry ins_idx of slot
// slot_idx is ins_idx * slot_feature_num_map[slot_idx], i.e. the cumulative
// offset for fixed-width rows of that slot.
//
// The diff rendering of this hunk kept both the old signature terminator
// (`int total_ins) {`) and the old body line (`... = ins_idx;`) next to the
// new ones; this is the reconstructed post-commit kernel.
//
// id_tensor            : array of device pointers, one LoD tensor per slot.
// len                  : (total_ins + 1) * slot_num (one thread per entry).
// total_ins            : total_instance + 1 entries per slot.
// slot_feature_num_map : per-slot feature count used as the row width.
__global__ void GraphFillSlotLodKernelOpt(uint64_t *id_tensor,
                                          int len,
                                          int total_ins,
                                          int *slot_feature_num_map) {
  CUDA_KERNEL_LOOP(idx, len) {
    int slot_idx = idx / total_ins;
    int ins_idx = idx % total_ins;
    ((uint64_t *)(id_tensor[slot_idx]))[ins_idx] =  // NOLINT
        ins_idx * slot_feature_num_map[slot_idx];
  }
}

Expand Down Expand Up @@ -838,7 +845,7 @@ int GraphDataGenerator::GenerateBatch() {
if (slot_num_ > 0) {
for (int i = 0; i < slot_num_; ++i) {
slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data<int64_t>(
{total_instance, 1}, this->place_);
{total_instance * h_slot_feature_num_map_[i], 1}, this->place_);
slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data<int64_t>(
{total_instance + 1}, this->place_);
}
Expand Down Expand Up @@ -934,39 +941,43 @@ int GraphDataGenerator::GenerateBatch() {
ins_cursor,
total_instance * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
uint64_t h_feature[total_instance * slot_num_]; // NOLINT
uint64_t h_feature[total_instance * fea_num_per_node_];
cudaMemcpy(h_feature,
feature_buf,
total_instance * slot_num_ * sizeof(uint64_t),
total_instance * fea_num_per_node_ * sizeof(uint64_t),
cudaMemcpyDeviceToHost);
for (int i = 0; i < total_instance; ++i) {
std::stringstream ss;
for (int j = 0; j < slot_num_; ++j) {
ss << h_feature[i * slot_num_ + j] << " ";
for (int j = 0; j < fea_num_per_node_; ++j) {
ss << h_feature[i * fea_num_per_node_ + j] << " ";
}
VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i
<< "] = " << (uint64_t)h_walk[i] << " feature["
<< i * slot_num_ << ".." << (i + 1) * slot_num_
<< i * fea_num_per_node_ << ".." << (i + 1) * fea_num_per_node_
<< "] = " << ss.str();
}
}

GraphFillSlotKernel<<<GET_BLOCKS(total_instance * slot_num_),
GraphFillSlotKernel<<<GET_BLOCKS(total_instance * fea_num_per_node_),
CUDA_NUM_THREADS,
0,
stream_>>>(
(uint64_t *)d_slot_tensor_ptr_->ptr(), // NOLINT
feature_buf,
total_instance * slot_num_,
total_instance,
slot_num_);
stream_>>>((uint64_t *)d_slot_tensor_ptr_->ptr(),
feature_buf,
total_instance * fea_num_per_node_,
total_instance,
slot_num_,
(int*)d_slot_feature_num_map_->ptr(),
fea_num_per_node_,
(int*)d_actual_slot_id_map_->ptr(),
(int*)d_fea_offset_map_->ptr());
GraphFillSlotLodKernelOpt<<<GET_BLOCKS((total_instance + 1) * slot_num_),
CUDA_NUM_THREADS,
0,
stream_>>>(
(uint64_t *)d_slot_lod_tensor_ptr_->ptr(), // NOLINT
(total_instance + 1) * slot_num_,
total_instance + 1);
total_instance + 1,
(int*)d_slot_feature_num_map_->ptr());
} else {
for (int i = 0; i < slot_num_; ++i) {
int feature_buf_offset =
Expand Down Expand Up @@ -1010,7 +1021,7 @@ int GraphDataGenerator::GenerateBatch() {
if (!gpu_graph_training_) return 1;
ins_buf_pair_len_ -= total_instance / 2;
if (debug_mode_) {
uint64_t h_slot_tensor[slot_num_][total_instance];
uint64_t h_slot_tensor[fea_num_per_node_][total_instance];
uint64_t h_slot_lod_tensor[slot_num_][total_instance + 1];
for (int i = 0; i < slot_num_; ++i) {
cudaMemcpy(h_slot_tensor[i],
Expand Down Expand Up @@ -1211,7 +1222,8 @@ int GraphDataGenerator::FillFeatureBuf(uint64_t *d_walk,

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
int ret = gpu_graph_ptr->get_feature_of_nodes(
gpuid_, d_walk, d_feature, key_num, slot_num_);
gpuid_, d_walk, d_feature, key_num, slot_num_,
(int*)d_slot_feature_num_map_->ptr(), fea_num_per_node_);
return ret;
}

Expand All @@ -1221,12 +1233,13 @@ int GraphDataGenerator::FillFeatureBuf(
platform::CUDADeviceGuard guard(gpuid_);

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
int ret = gpu_graph_ptr->get_feature_of_nodes(
gpuid_,
(uint64_t *)d_walk->ptr(), // NOLINT
(uint64_t *)d_feature->ptr(), // NOLINT
buf_size_,
slot_num_);
int ret = gpu_graph_ptr->get_feature_of_nodes(gpuid_,
(uint64_t *)d_walk->ptr(),
(uint64_t *)d_feature->ptr(),
buf_size_,
slot_num_,
(int*)d_slot_feature_num_map_->ptr(),
fea_num_per_node_);
return ret;
}

Expand Down Expand Up @@ -1413,6 +1426,32 @@ void GraphDataGenerator::AllocResource(
slot_num_ = (feed_vec_.size() - 4 - samples_.size() * 5) / 2;
}

auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
h_slot_feature_num_map_ = gpu_graph_ptr->slot_feature_num_map();
fea_num_per_node_ = 0;
for (int i = 0; i < slot_num_; ++i) {
fea_num_per_node_ += h_slot_feature_num_map_[i];
}
std::vector<int> h_actual_slot_id_map, h_fea_offset_map;
h_actual_slot_id_map.resize(fea_num_per_node_);
h_fea_offset_map.resize(fea_num_per_node_);
for (int slot_id = 0, fea_idx = 0; slot_id < slot_num_; ++slot_id) {
for (int j = 0; j < h_slot_feature_num_map_[slot_id]; ++j, ++fea_idx) {
h_actual_slot_id_map[fea_idx] = slot_id;
h_fea_offset_map[fea_idx] = j;
}
}

d_slot_feature_num_map_ = memory::Alloc(place, slot_num_ * sizeof(int));
cudaMemcpy(d_slot_feature_num_map_->ptr(), h_slot_feature_num_map_.data(),
sizeof(int) * slot_num_, cudaMemcpyHostToDevice);
d_actual_slot_id_map_ = memory::Alloc(place, fea_num_per_node_ * sizeof(int));
cudaMemcpy(d_actual_slot_id_map_->ptr(), h_actual_slot_id_map.data(),
sizeof(int) * fea_num_per_node_, cudaMemcpyHostToDevice);
d_fea_offset_map_ = memory::Alloc(place, fea_num_per_node_ * sizeof(int));
cudaMemcpy(d_fea_offset_map_->ptr(), h_fea_offset_map.data(),
sizeof(int) * fea_num_per_node_, cudaMemcpyHostToDevice);

// d_device_keys_.resize(h_device_keys_.size());
VLOG(2) << "h_device_keys size: " << h_device_keys_.size();
infer_node_type_start_ = std::vector<int>(h_device_keys_.size(), 0);
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,9 @@ class GraphDataGenerator {
std::shared_ptr<phi::Allocation> d_feature_;
std::shared_ptr<phi::Allocation> d_len_per_row_;
std::shared_ptr<phi::Allocation> d_random_row_;
std::shared_ptr<phi::Allocation> d_slot_feature_num_map_;
std::shared_ptr<phi::Allocation> d_actual_slot_id_map_;
std::shared_ptr<phi::Allocation> d_fea_offset_map_;
//
std::vector<std::shared_ptr<phi::Allocation>> d_sampleidx2rows_;
int cur_sampleidx2row_;
Expand Down Expand Up @@ -987,6 +990,8 @@ class GraphDataGenerator {
BufState buf_state_;
int batch_size_;
int slot_num_;
std::vector<int> h_slot_feature_num_map_;
int fea_num_per_node_;
int shuffle_seed_;
int debug_mode_;
std::vector<int> first_node_type_;
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ class GpuPsGraphTable
std::vector<std::shared_ptr<phi::Allocation>> edge_type_graphs);
std::vector<std::shared_ptr<phi::Allocation>> get_edge_type_graph(int gpu_id, int edge_type_len);
int get_feature_of_nodes(
int gpu_id, uint64_t *d_walk, uint64_t *d_offset, int size, int slot_num);
int gpu_id, uint64_t *d_walk, uint64_t *d_offset, int size, int slot_num,
int* d_slot_feature_num_map, int fea_num_per_node);

NodeQueryResult query_node_list(int gpu_id,
int idx,
Expand Down
Loading

0 comments on commit 520e680

Please sign in to comment.