Skip to content

Commit

Permalink
Add debug log (PaddlePaddle#131)
Browse files Browse the repository at this point in the history
* Add debug log

* Add debug log

Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0008.yq01.baidu.com>
  • Loading branch information
lxsbupt and root authored Oct 9, 2022
1 parent c6a07b2 commit 4380355
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 6 deletions.
33 changes: 29 additions & 4 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */
#include <sstream>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h"
#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"
#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
Expand Down Expand Up @@ -874,11 +875,14 @@ int GraphDataGenerator::InsertTable(const unsigned long *d_keys,
}

void GraphDataGenerator::DoWalk() {
  // Snapshot GPU memory before and after the walk so memory growth caused by
  // sampling is visible in the logs.
  const int device_id = place_.GetDeviceId();
  debug_gpu_memory_info(device_id, "DoWalk start");
  // Inference fills the inference buffer; otherwise we are training and fill
  // the random-walk buffer.
  if (!gpu_graph_training_) {
    FillInferBuf();
  } else {
    FillWalkBuf();
  }
  debug_gpu_memory_info(device_id, "DoWalk end");
}

void GraphDataGenerator::clear_gpu_mem() {
Expand Down Expand Up @@ -967,6 +971,7 @@ int GraphDataGenerator::FillWalkBuf() {
cudaMemsetAsync(walk, 0, buf_size_ * sizeof(uint64_t), sample_stream_);
// cudaMemsetAsync(
// len_per_row, 0, once_max_sample_keynum * sizeof(int), sample_stream_);
int sample_times = 0;
int i = 0;
total_row_ = 0;

Expand All @@ -980,6 +985,7 @@ int GraphDataGenerator::FillWalkBuf() {
size_t node_type_len = first_node_type.size();
int remain_size =
buf_size_ - walk_degree_ * once_sample_startid_len_ * walk_len_;
int total_samples = 0;

while (i <= remain_size) {
int cur_node_idx = cursor % node_type_len;
Expand Down Expand Up @@ -1028,6 +1034,7 @@ int GraphDataGenerator::FillWalkBuf() {
int step = 1;
VLOG(2) << "sample edge type: " << path[0] << " step: " << 1;
jump_rows_ = sample_res.total_sample_size;
total_samples += sample_res.total_sample_size;
VLOG(2) << "i = " << i << " start = " << start << " tmp_len = " << tmp_len
<< " cursor = " << node_type << " cur_node_idx = " << cur_node_idx
<< " jump row: " << jump_rows_;
Expand Down Expand Up @@ -1066,11 +1073,16 @@ int GraphDataGenerator::FillWalkBuf() {
VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx];
}
}

VLOG(2) << "sample, step=" << step << " sample_keys=" << tmp_len
<< " sample_res_len=" << sample_res.total_sample_size;

/////////
step++;
size_t path_len = path.size();
for (; step < walk_len_; step++) {
if (sample_res.total_sample_size == 0) {
VLOG(2) << "sample finish, step=" << step;
break;
}
auto sample_key_mem = sample_res.actual_val_mem;
Expand All @@ -1085,9 +1097,8 @@ int GraphDataGenerator::FillWalkBuf() {
sample_res.total_sample_size);
int sample_key_len = sample_res.total_sample_size;
sample_res = gpu_graph_ptr->graph_neighbor_sample_v3(q, false);
total_samples += sample_res.total_sample_size;
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
// table_->insert(sample_res.actual_val, sample_res.total_sample_size,
// d_uniq_node_num, sample_stream_);
if (InsertTable(sample_res.actual_val, sample_res.total_sample_size) !=
0) {
VLOG(2) << "in step: " << step << ", table is full";
Expand All @@ -1109,13 +1120,17 @@ int GraphDataGenerator::FillWalkBuf() {
VLOG(2) << "h_walk[" << xx << "]: " << h_walk[xx];
}
}

VLOG(2) << "sample, step=" << step << " sample_keys=" << sample_key_len
<< " sample_res_len=" << sample_res.total_sample_size;
}
// 此时更新全局采样状态
if (update == true) {
node_type_start[node_type] = tmp_len + start;
i += jump_rows_ * walk_len_;
total_row_ += jump_rows_;
cursor += 1;
sample_times++;
} else {
VLOG(2) << "table is full, not update stat!";
break;
Expand All @@ -1133,7 +1148,6 @@ int GraphDataGenerator::FillWalkBuf() {
thrust::device_pointer_cast(d_random_row),
engine);

VLOG(2) << "FillWalkBuf: " << total_row_;
cudaStreamSynchronize(sample_stream_);
shuffle_seed_ = engine();

Expand Down Expand Up @@ -1197,7 +1211,6 @@ int GraphDataGenerator::FillWalkBuf() {
cudaStreamSynchronize(sample_stream_);

host_vec_.resize(h_uniq_node_num);
VLOG(0) << "uniq node num: " << h_uniq_node_num;
cudaMemcpyAsync(host_vec_.data(),
d_uniq_node_ptr,
sizeof(uint64_t) * h_uniq_node_num,
Expand Down Expand Up @@ -1320,6 +1333,13 @@ int GraphDataGenerator::FillWalkBuf() {
sample_stream_);
cudaStreamSynchronize(sample_stream_);
}

VLOG(0) << "sample_times:" << sample_times
<< ", d_walk_size:" << buf_size_
<< ", d_walk_offset:" << i
<< ", total_rows:" << total_row_
<< ", total_samples:" << total_samples
<< ", h_uniq_node_num:" << h_uniq_node_num;
}
return total_row_ != 0;
}
Expand All @@ -1333,6 +1353,7 @@ void GraphDataGenerator::AllocResource(int thread_id,
gpuid_ = gpu_graph_ptr->device_id_mapping[thread_id];
thread_id_ = thread_id;
place_ = platform::CUDAPlace(gpuid_);
debug_gpu_memory_info(gpuid_, "AllocResource start");

platform::CUDADeviceGuard guard(gpuid_);
if (FLAGS_gpugraph_storage_mode != GpuGraphStorageMode::WHOLE_HBM) {
Expand Down Expand Up @@ -1443,6 +1464,8 @@ void GraphDataGenerator::AllocResource(int thread_id,
memory::AllocShared(place_, slot_num_ * sizeof(uint64_t *));

cudaStreamSynchronize(sample_stream_);

debug_gpu_memory_info(gpuid_, "AllocResource end");
}

void GraphDataGenerator::SetConfig(
Expand Down Expand Up @@ -1475,7 +1498,9 @@ void GraphDataGenerator::SetConfig(
std::string first_node_type = graph_config.first_node_type();
std::string meta_path = graph_config.meta_path();
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
debug_gpu_memory_info("init_conf start");
gpu_graph_ptr->init_conf(first_node_type, meta_path);
debug_gpu_memory_info("init_conf end");
};

} // namespace framework
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
<< " dim index: " << j << " contains feasign nums: "
<< gpu_task->device_dim_ptr_[i][j].size();
}
VLOG(1) << i << " card with dynamic mf contains feasign nums total: "
VLOG(0) << i << " card with dynamic mf contains feasign nums total: "
<< feature_keys_count[i];
size_max = std::max(size_max, feature_keys_count[i]);
}
Expand Down
31 changes: 30 additions & 1 deletion paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator(
: underlying_allocator_(underlying_allocator),
alignment_(alignment),
chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)),
allow_free_idle_chunk_(allow_free_idle_chunk) {}
allow_free_idle_chunk_(allow_free_idle_chunk) {
total_alloc_times_ = 0;
total_alloc_size_ = 0;
total_free_times_ = 0;
total_free_size_ = 0;
}

phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl(
size_t unaligned_size) {
Expand Down Expand Up @@ -112,6 +117,8 @@ phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl(
VLOG(2) << "Not found and reallocate " << realloc_size << "("
<< static_cast<void *>(p) << "), and remaining " << remaining_size;
}
++total_alloc_times_;
total_alloc_size_ += size;
VLOG(10) << "Alloc " << block_it->size_ << " bytes, ptr = " << block_it->ptr_;
return new BlockAllocation(block_it);
}
Expand All @@ -126,6 +133,9 @@ void AutoGrowthBestFitAllocator::FreeImpl(phi::Allocation *allocation) {
auto block_it = static_cast<BlockAllocation *>(allocation)->block_it_;
auto &blocks = block_it->chunk_->blocks_;

total_free_times_ += 1;
total_free_size_ += block_it->size_;

block_it->is_free_ = true;

if (block_it != blocks.begin()) {
Expand Down Expand Up @@ -176,9 +186,28 @@ uint64_t AutoGrowthBestFitAllocator::FreeIdleChunks() {
++chunk_it;
}
}

Trace();
return bytes;
}

void AutoGrowthBestFitAllocator::Trace() const {
size_t cur_idle_bytes = 0;
auto it = free_blocks_.begin();
for (; it != free_blocks_.end(); ++it) {
cur_idle_bytes += it->second->size_;
}

VLOG(0) << "alloc:" << total_alloc_size_ / double(1024*1024)
<< "m free:" << total_free_size_ / double(1024*1024)
<< "m busy:" << (total_alloc_size_ - total_free_size_) / double(1024*1024)
<< "m idle:" << cur_idle_bytes / double(1024*1024)
<< "m alloc_times:" << total_alloc_times_
<< " free_times:" << total_free_times_
<< " free_blocks_num:" << free_blocks_.size()
<< " curr_chunks_num:" << chunks_.size();
}

} // namespace allocation
} // namespace memory
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class AutoGrowthBestFitAllocator : public Allocator {

private:
uint64_t FreeIdleChunks();
void Trace() const;

template <typename T>
using List = std::list<T>;
Expand Down Expand Up @@ -93,6 +94,12 @@ class AutoGrowthBestFitAllocator : public Allocator {
size_t chunk_size_;
bool allow_free_idle_chunk_;

// stat info
size_t total_alloc_times_;
size_t total_alloc_size_;
size_t total_free_times_;
size_t total_free_size_;

SpinLock spinlock_;
};

Expand Down

0 comments on commit 4380355

Please sign in to comment.