Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add check hbm #5

Closed
wants to merge 10 commits (author and source/target branch names were not captured in this page extract)
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ add_definitions("-DCUDA_VERSION_MINOR=\"${CUDA_VERSION_MINOR}\"")
add_definitions("-DCUDA_TOOLKIT_ROOT_DIR=\"${CUDA_TOOLKIT_ROOT_DIR}\"")

# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
#select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
set(NVCC_FLAGS_EXTRA "-gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}")

Expand Down
4 changes: 3 additions & 1 deletion cmake/external/pslib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_PREFIX_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
#DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
#DOWNLOAD_COMMAND cp /yaoxuefeng/repos/pslib_rdma_tmp/baidu/paddlepaddle/pslib/pslib.tar.gz .
DOWNLOAD_COMMAND cp /zhangminxu/so_debug/baidu/paddlepaddle/pslib_2/pslib.tar.gz ./
&& tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HashTable {
Sgd sgd, gpuStream_t stream);

template <typename Sgd>
void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
void update(int gpu_num, const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
gpuStream_t stream);

int size() { return container_->size(); }
Expand Down
13 changes: 8 additions & 5 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ __global__ void update_kernel(Table* table,
}

template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
__global__ void dy_mf_update_kernel(int gpu_num, Table* table,
const typename Table::key_type* const keys,
const char* const grads, size_t len,
Sgd sgd, size_t grad_value_size) {
Expand All @@ -220,9 +220,12 @@ __global__ void dy_mf_update_kernel(Table* table,
auto it = table->find(keys[i]);
if (it != table->end()) {
FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
sgd.dy_mf_update_value((it.getter())->second, *cur);
// sgd.dy_mf_update_value((it.getter())->second, *cur);
} else {
if(keys[i] != 0) printf("push miss key: %d", keys[i]);
if (keys[i] != 0) {
// get device id
printf("push miss key: %llu %d\n", keys[i], gpu_num);
}
}
}
}
Expand Down Expand Up @@ -387,15 +390,15 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,

template <typename KeyType, typename ValType>
template <typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
void HashTable<KeyType, ValType>::update(int gpu_num, const KeyType* d_keys,
const char* d_grads, size_t len,
Sgd sgd, gpuStream_t stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;

dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(gpu_num,
container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class HeterComm {

void init_path();

void create_storage(int start_index, int end_index, int keylen, int vallen);
void create_storage(int start_index, int end_index, size_t keylen, size_t vallen);
void destroy_storage(int start_index, int end_index);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
KeyType* src_key, GradType* src_val);
Expand Down
Loading