Skip to content

Commit

Permalink
[GPUGraph] graph sample v2 (PaddlePaddle#87)
Browse files Browse the repository at this point in the history
* change load node and edge from local to cpu (PaddlePaddle#83)

* change load node and edge

* remove useless code

Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>

* extract pull sparse as single stage(PaddlePaddle#85)

Co-authored-by: yangjunchao <yangjunchao@baidu.com>

* support ssdsparsetable;test=develop (PaddlePaddle#81)

* graph sample v2

* remove log

Co-authored-by: miaoli06 <106585574+miaoli06@users.noreply.github.com>
Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
Co-authored-by: chao9527 <33347532+chao9527@users.noreply.github.com>
Co-authored-by: yangjunchao <yangjunchao@baidu.com>
Co-authored-by: danleifeng <52735331+danleifeng@users.noreply.github.com>
  • Loading branch information
6 people authored Aug 22, 2022
1 parent a26eaa5 commit 5529e61
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 250 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ int32_t SSDSparseTable::Initialize() {
MemorySparseTable::Initialize();
_db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize(FLAGS_rocksdb_path, _real_local_shard_num);
VLOG(0) << "initalize SSDSparseTable succ";
return 0;
}

Expand Down Expand Up @@ -549,7 +550,11 @@ int32_t SSDSparseTable::Save(const std::string& path,
std::string table_path = TableDir(path);
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
#ifdef PADDLE_WITH_GPU_GRAPH
int thread_num = _real_local_shard_num;
#else
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
#endif

// std::atomic<uint32_t> feasign_size;
std::atomic<uint32_t> feasign_size_all{0};
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/data_feed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,10 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) {
}
}

// Allocates the per-thread GPU graph sampling resources for this reader by
// binding the graph data generator to this feed's thread id and output
// feed tensors. Invoked once per reader during dataset reader creation
// (SlotRecordDataset::CreateReaders), before Start() is called.
void SlotRecordInMemoryDataFeed::InitGraphResource() {
gpu_graph_data_generator_.AllocResource(thread_id_, feed_vec_);
}

void SlotRecordInMemoryDataFeed::LoadIntoMemory() {
VLOG(3) << "SlotRecord LoadIntoMemory() begin, thread_id=" << thread_id_;
if (!so_parser_name_.empty()) {
Expand Down Expand Up @@ -2654,7 +2658,7 @@ bool SlotRecordInMemoryDataFeed::Start() {
pack_ = BatchGpuPackMgr().get(this->GetPlace(), used_slots_info_);
#endif
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
gpu_graph_data_generator_.AllocResource(this->place_, feed_vec_);
gpu_graph_data_generator_.SetFeedVec(feed_vec_);
#endif
return true;
}
Expand Down
368 changes: 194 additions & 174 deletions paddle/fluid/framework/data_feed.cu

Large diffs are not rendered by default.

17 changes: 7 additions & 10 deletions paddle/fluid/framework/data_feed.h
Original file line number Diff line number Diff line change
Expand Up @@ -893,10 +893,10 @@ struct BufState {
class GraphDataGenerator {
public:
GraphDataGenerator(){};
virtual ~GraphDataGenerator(){};
virtual ~GraphDataGenerator();
void SetConfig(const paddle::framework::DataFeedDesc& data_feed_desc);
void AllocResource(const paddle::platform::Place& place,
std::vector<LoDTensor*> feed_vec);
void AllocResource(int thread_id, std::vector<LoDTensor*> feed_vec);
void SetFeedVec(std::vector<LoDTensor*> feed_vec);
int AcquireInstance(BufState* state);
int GenerateBatch();
int FillWalkBuf(std::shared_ptr<phi::Allocation> d_walk);
Expand Down Expand Up @@ -933,7 +933,8 @@ class GraphDataGenerator {
int64_t* id_tensor_ptr_;
int64_t* show_tensor_ptr_;
int64_t* clk_tensor_ptr_;
cudaStream_t stream_;
cudaStream_t train_stream_;
cudaStream_t sample_stream_;
paddle::platform::Place place_;
std::vector<LoDTensor*> feed_vec_;
std::vector<size_t> offset_;
Expand All @@ -951,10 +952,6 @@ class GraphDataGenerator {
std::shared_ptr<phi::Allocation> d_sample_keys_;
int sample_keys_len_;

std::set<int> finish_node_type_;
std::unordered_map<int, size_t> node_type_start_;
std::vector<int> infer_node_type_start_;

std::shared_ptr<phi::Allocation> d_ins_buf_;
std::shared_ptr<phi::Allocation> d_feature_buf_;
std::shared_ptr<phi::Allocation> d_pair_num_;
Expand All @@ -970,8 +967,6 @@ class GraphDataGenerator {
int slot_num_;
int shuffle_seed_;
int debug_mode_;
std::vector<int> first_node_type_;
std::vector<std::vector<int>> meta_path_;
bool gpu_graph_training_;
};

Expand Down Expand Up @@ -1037,6 +1032,7 @@ class DataFeed {
virtual void SetParseLogKey(bool parse_logkey) {}
virtual void SetEnablePvMerge(bool enable_pv_merge) {}
virtual void SetCurrentPhase(int current_phase) {}
virtual void InitGraphResource() {}
virtual void SetDeviceKeys(std::vector<uint64_t>* device_keys, int type) {
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
gpu_graph_data_generator_.SetDeviceKeys(device_keys, type);
Expand Down Expand Up @@ -1637,6 +1633,7 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
// CustomParser* parser) {}
virtual void PutToFeedVec(const std::vector<SlotRecord>& ins_vec) {}

virtual void InitGraphResource(void);
virtual void LoadIntoMemoryByCommand(void);
virtual void LoadIntoMemoryByLib(void);
virtual void LoadIntoMemoryByLine(void);
Expand Down
76 changes: 31 additions & 45 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

USE_INT_STAT(STAT_total_feasign_num_in_mem);
DECLARE_bool(graph_get_neighbor_id);
DECLARE_int32(gpugraph_storage_mode);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -473,60 +474,44 @@ void DatasetImpl<T>::LoadIntoMemory() {
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
auto node_to_id = gpu_graph_ptr->feature_to_id;
auto edge_to_id = gpu_graph_ptr->edge_to_id;
graph_all_type_total_keys_.resize(node_to_id.size());
int cnt = 0;
// set sample start node
for (auto& iter : node_to_id) {
int node_idx = iter.second;
std::vector<std::vector<uint64_t>> gpu_graph_device_keys;
gpu_graph_ptr->get_all_id(
1, node_idx, thread_num_, &gpu_graph_device_keys);
auto& type_total_key = graph_all_type_total_keys_[cnt];
type_total_key.resize(thread_num_);
for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
VLOG(2) << "node type: " << node_idx << ", gpu_graph_device_keys[" << i
<< "] = " << gpu_graph_device_keys[i].size();
for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) {
type_total_key[i].push_back(gpu_graph_device_keys[i][j]);
}
}

for (size_t i = 0; i < readers_.size(); i++) {
readers_[i]->SetDeviceKeys(&type_total_key[i], node_idx);
readers_[i]->SetGpuGraphMode(gpu_graph_mode_);
}
cnt++;
}

// add node embedding id
std::vector<std::vector<uint64_t>> gpu_graph_device_keys;
gpu_graph_ptr->get_node_embedding_ids(thread_num_, &gpu_graph_device_keys);
for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) {
gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]);
}
for (size_t i = 0; i < readers_.size(); i++) {
readers_[i]->SetGpuGraphMode(gpu_graph_mode_);
}

// add feature embedding id
VLOG(2) << "begin add feature_id into gpu_graph_total_keys_ size["
<< gpu_graph_total_keys_.size() << "]";
for (auto& iter : node_to_id) {
if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::HBM) {
// add node embedding id
std::vector<std::vector<uint64_t>> gpu_graph_device_keys;
int node_idx = iter.second;
gpu_graph_ptr->get_all_feature_ids(
1, node_idx, thread_num_, &gpu_graph_device_keys);
gpu_graph_ptr->get_node_embedding_ids(thread_num_,
&gpu_graph_device_keys);
for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
VLOG(2) << "begin node type: " << node_idx << ", gpu_graph_device_keys["
<< i << "] = " << gpu_graph_device_keys[i].size();
for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) {
gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]);
}
VLOG(2) << "end node type: " << node_idx << ", gpu_graph_device_keys["
<< i << "] = " << gpu_graph_device_keys[i].size();
}

// add feature embedding id
VLOG(2) << "begin add feature_id into gpu_graph_total_keys_ size["
<< gpu_graph_total_keys_.size() << "]";
for (auto& iter : node_to_id) {
std::vector<std::vector<uint64_t>> gpu_graph_device_keys;
int node_idx = iter.second;
gpu_graph_ptr->get_all_feature_ids(
1, node_idx, thread_num_, &gpu_graph_device_keys);
for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) {
VLOG(2) << "begin node type: " << node_idx
<< ", gpu_graph_device_keys[" << i
<< "] = " << gpu_graph_device_keys[i].size();
for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) {
gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]);
}
VLOG(2) << "end node type: " << node_idx << ", gpu_graph_device_keys["
<< i << "] = " << gpu_graph_device_keys[i].size();
}
}
VLOG(2) << "end add feature_id into gpu_graph_total_keys_ size["
<< gpu_graph_total_keys_.size() << "]";
} else if (FLAGS_gpugraph_storage_mode == GpuGraphStorageMode::CPU) {
}
VLOG(2) << "end add feature_id into gpu_graph_total_keys_ size["
<< gpu_graph_total_keys_.size() << "]";
#endif
} else {
for (int64_t i = 0; i < thread_num_; ++i) {
Expand Down Expand Up @@ -1780,6 +1765,7 @@ void SlotRecordDataset::CreateReaders() {
readers_[i]->SetParseLogKey(parse_logkey_);
readers_[i]->SetEnablePvMerge(enable_pv_merge_);
readers_[i]->SetCurrentPhase(current_phase_);
readers_[i]->InitGraphResource();
if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get());
}
Expand Down
64 changes: 50 additions & 14 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,14 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
int h_left[total_gpu]; // NOLINT
int h_right[total_gpu]; // NOLINT

auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
auto d_left =
memory::Alloc(place,
total_gpu * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
auto d_right =
memory::Alloc(place,
total_gpu * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
int default_value = 0;
Expand All @@ -710,15 +716,26 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream));
CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream));
//
auto d_idx = memory::Alloc(place, len * sizeof(int));
auto d_idx =
memory::Alloc(place,
len * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

auto d_shard_keys = memory::Alloc(place, len * sizeof(uint64_t));
auto d_shard_keys =
memory::Alloc(place,
len * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
uint64_t* d_shard_keys_ptr = reinterpret_cast<uint64_t*>(d_shard_keys->ptr());
auto d_shard_vals =
memory::Alloc(place, sample_size * len * sizeof(uint64_t));
memory::Alloc(place,
sample_size * len * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
uint64_t* d_shard_vals_ptr = reinterpret_cast<uint64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
auto d_shard_actual_sample_size =
memory::Alloc(place,
len * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());

Expand Down Expand Up @@ -919,8 +936,10 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
int total_sample_size = thrust::reduce(t_actual_sample_size.begin(),
t_actual_sample_size.end());
result.actual_val_mem =
memory::AllocShared(place, total_sample_size * sizeof(uint64_t));
result.actual_val_mem = memory::AllocShared(
place,
total_sample_size * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
result.actual_val = (uint64_t*)(result.actual_val_mem)->ptr();
result.set_total_sample_size(total_sample_size);
Expand Down Expand Up @@ -1011,23 +1030,40 @@ int GpuPsGraphTable::get_feature_of_nodes(int gpu_id,
int total_gpu = resource_->total_device();
auto stream = resource_->local_stream(gpu_id, 0);
auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
auto d_left =
memory::Alloc(place,
total_gpu * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
auto d_right =
memory::Alloc(place,
total_gpu * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
CUDA_CHECK(cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream));
CUDA_CHECK(cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream));
//
auto d_idx = memory::Alloc(place, node_num * sizeof(int));
auto d_idx =
memory::Alloc(place,
node_num * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
auto d_shard_keys = memory::Alloc(place, node_num * sizeof(uint64_t));
auto d_shard_keys =
memory::Alloc(place,
node_num * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
uint64_t* d_shard_keys_ptr = reinterpret_cast<uint64_t*>(d_shard_keys->ptr());
auto d_shard_vals =
memory::Alloc(place, slot_num * node_num * sizeof(uint64_t));
memory::Alloc(place,
slot_num * node_num * sizeof(uint64_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
uint64_t* d_shard_vals_ptr = reinterpret_cast<uint64_t*>(d_shard_vals->ptr());
auto d_shard_actual_size = memory::Alloc(place, node_num * sizeof(int));
auto d_shard_actual_size =
memory::Alloc(place,
node_num * sizeof(int),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
int* d_shard_actual_size_ptr =
reinterpret_cast<int*>(d_shard_actual_size->ptr());
Expand Down
72 changes: 69 additions & 3 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,72 @@ void GraphGpuWrapper::set_device(std::vector<int> ids) {
}
}

// Parses the walk configuration strings and initializes per-device sampling
// state. Idempotent and thread-safe: a function-local mutex plus the
// conf_initialized_ flag guarantee the parsing runs exactly once.
//
// @param first_node_type  ';'-separated node type names used as walk start
//                         points; each must exist in feature_to_id.
// @param meta_path        ';'-separated meta paths, each a '-'-separated list
//                         of edge type names; each must exist in edge_to_id.
//                         meta_paths[i] corresponds to first_node_type[i].
void GraphGpuWrapper::init_conf(const std::string &first_node_type,
                                const std::string &meta_path) {
  static std::mutex mutex;
  {
    std::lock_guard<std::mutex> lock(mutex);
    if (conf_initialized_) {
      return;
    }
    VLOG(2) << "init path config";
    conf_initialized_ = true;

    // Resolve the start node type names to their integer ids.
    auto node_types =
        paddle::string::split_string<std::string>(first_node_type, ";");
    VLOG(2) << "node_types: " << first_node_type;
    for (auto &type : node_types) {
      auto iter = feature_to_id.find(type);
      PADDLE_ENFORCE_NE(iter,
                        feature_to_id.end(),
                        platform::errors::NotFound(
                            "(%s) is not found in feature_to_id.", type));
      VLOG(2) << "feature_to_id[" << type << "] = " << iter->second;
      first_node_type_.push_back(iter->second);
    }

    // Resolve each meta path into a sequence of edge type ids.
    meta_path_.resize(first_node_type_.size());
    auto meta_paths = paddle::string::split_string<std::string>(meta_path, ";");
    for (size_t i = 0; i < meta_paths.size(); i++) {
      auto path = meta_paths[i];
      auto nodes = paddle::string::split_string<std::string>(path, "-");
      for (auto &node : nodes) {
        auto iter = edge_to_id.find(node);
        PADDLE_ENFORCE_NE(iter,
                          edge_to_id.end(),
                          platform::errors::NotFound(
                              "(%s) is not found in edge_to_id.", node));
        VLOG(2) << "edge_to_id[" << node << "] = " << iter->second;
        meta_path_[i].push_back(iter->second);
      }
    }

    // Per-device state is indexed by the physical device id, so size the
    // containers to cover the largest id in the mapping.
    int max_dev_id = 0;
    for (size_t i = 0; i < device_id_mapping.size(); i++) {
      if (device_id_mapping[i] > max_dev_id) {
        max_dev_id = device_id_mapping[i];
      }
    }
    finish_node_type_.resize(max_dev_id + 1);
    node_type_start_.resize(max_dev_id + 1);
    infer_node_type_start_.resize(max_dev_id + 1);
    for (size_t i = 0; i < device_id_mapping.size(); i++) {
      // Index by the mapped device id, not the mapping position: the vectors
      // above are sized by max_dev_id + 1, and device_id_mapping need not be
      // the identity permutation.
      int dev_id = device_id_mapping[i];
      auto &node_type_start = node_type_start_[dev_id];
      auto &infer_node_type_start = infer_node_type_start_[dev_id];
      auto &finish_node_type = finish_node_type_[dev_id];
      finish_node_type.clear();
      // Only configured start node types get a cursor; lookups were already
      // validated by the PADDLE_ENFORCE_NE above.
      for (auto &type : node_types) {
        auto iter = feature_to_id.find(type);
        node_type_start[iter->second] = 0;
        infer_node_type_start[iter->second] = 0;
      }
    }
  }
}

int GraphGpuWrapper::get_all_id(int type,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
Expand Down Expand Up @@ -160,9 +226,9 @@ void GraphGpuWrapper::load_node_and_edge(std::string etype2files,
std::string graph_data_local_path,
int part_num,
bool reverse) {
((GpuPsGraphTable *)graph_table)
->cpu_graph_table_->load_node_and_edge_file(
etype2files, ntype2files, graph_data_local_path, part_num, reverse);
((GpuPsGraphTable *)graph_table)
->cpu_graph_table_->load_node_and_edge_file(
etype2files, ntype2files, graph_data_local_path, part_num, reverse);
}

void GraphGpuWrapper::add_table_feat_conf(std::string table_name,
Expand Down
Loading

0 comments on commit 5529e61

Please sign in to comment.