
Commit

Merge c0f57f6 into 9b5e2b9
galeselee authored Jul 20, 2021
2 parents 9b5e2b9 + c0f57f6 commit 7f2fe56
Showing 30 changed files with 323 additions and 305 deletions.
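
In this merge, the backend-specific error-checking macros (cudaErrcheck for CUDA, hipErrcheck for ROCm) are replaced by a single DPErrcheck/DPAssert pair defined once per backend header, and new DPGetDeviceCount and DPSetDevice wrappers let DeepPot.cc collapse its duplicated #if GOOGLE_CUDA / #if TENSORFLOW_USE_ROCM blocks into one #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM block.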
77 changes: 11 additions & 66 deletions source/api_cc/src/DeepPot.cc
@@ -1,38 +1,11 @@
#include "DeepPot.h"
#include "AtomMap.h"
#include <stdexcept>
#include "device.h"

using namespace tensorflow;
using namespace deepmd;

#if GOOGLE_CUDA
#include "cuda_runtime.h"

#define cudaErrcheck(res) { cudaAssert((res), __FILE__, __LINE__); }
inline void cudaAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"cuda assert: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
#endif

#if TENSORFLOW_USE_ROCM
#include<hip/hip_runtime.h>

#define hipErrcheck(res) { hipAssert((res), __FILE__, __LINE__); }
inline void hipAssert(hipError_t code, const char *file, int line, bool abort=true)
{
if (code != hipSuccess)
{
fprintf(stderr,"hip assert: %s %s %d\n", hipGetErrorString(code), file, line);
if (abort) exit(code);
}
}
#endif //TENSORFLOW_USE_ROCM

static
std::vector<int> cum_sum (const std::vector<int32> & n_sel) {
std::vector<int> sec;
@@ -218,32 +191,18 @@ init (const std::string & model, const int & gpu_rank, const std::string & file_
else
graph_def.ParseFromString(file_content);
int gpu_num = -1;
#if GOOGLE_CUDA
cudaGetDeviceCount(&gpu_num); // check current device environment
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
DPGetDeviceCount(gpu_num); // check current device environment
if (gpu_num > 0) {
options.config.set_allow_soft_placement(true);
options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.9);
options.config.mutable_gpu_options()->set_allow_growth(true);
cudaErrcheck(cudaSetDevice(gpu_rank % gpu_num));
DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
std::string str = "/gpu:";
str += std::to_string(gpu_rank % gpu_num);
graph::SetDefaultDevice(str, &graph_def);
}
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM
hipGetDeviceCount(&gpu_num); // check current device environment
if (gpu_num > 0) {
options.config.set_allow_soft_placement(true);
options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.9);
options.config.mutable_gpu_options()->set_allow_growth(true);
hipErrcheck(hipSetDevice(gpu_rank % gpu_num));
std::string str = "/gpu:";
str += std::to_string(gpu_rank % gpu_num);
graph::SetDefaultDevice(str, &graph_def);
}
#endif // TENSORFLOW_USE_ROCM

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
check_status (NewSession(options, &session));
check_status (session->Create(graph_def));
rcut = get_scalar<VALUETYPE>("descrpt_attr/rcut");
@@ -552,13 +511,9 @@ init (const std::vector<std::string> & models, const int & gpu_rank, const std::
graph_defs.resize(numb_models);

int gpu_num = -1;
#if GOOGLE_CUDA
cudaGetDeviceCount(&gpu_num);
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM
hipGetDeviceCount(&gpu_num);
#endif //TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
DPGetDeviceCount(gpu_num);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

SessionOptions options;
options.config.set_inter_op_parallelism_threads(num_inter_nthreads);
@@ -569,24 +524,14 @@ init (const std::vector<std::string> & models, const int & gpu_rank, const std::
else
graph_defs[ii].ParseFromString(file_contents[ii]);
}
#if GOOGLE_CUDA
if (gpu_num > 0) {
options.config.set_allow_soft_placement(true);
options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.9);
options.config.mutable_gpu_options()->set_allow_growth(true);
cudaErrcheck(cudaSetDevice(gpu_rank % gpu_num));
}
#endif // GOOGLE_CUDA


#if TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (gpu_num > 0) {
options.config.set_allow_soft_placement(true);
options.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.9);
options.config.mutable_gpu_options()->set_allow_growth(true);
hipErrcheck(hipSetDevice(gpu_rank % gpu_num));
DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
}
#endif // TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

for (unsigned ii = 0; ii < numb_models; ++ii) {
if (gpu_num > 0) {
(diff truncated)
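
For orientation, here is a minimal sketch (not part of the commit) of the device-selection pattern that init() now shares across both backends. It assumes "device.h" dispatches to gpu_cuda.h or gpu_rocm.h according to the build flags, which is what the new include in DeepPot.cc suggests.

// Sketch only: the unified GPU selection used by DeepPot::init above.
// Assumption: "device.h" provides the DPGetDeviceCount/DPSetDevice/DPErrcheck
// wrappers for whichever backend (CUDA or ROCm) the build enables.
#include "device.h"

void select_device(int gpu_rank) {
  int gpu_num = -1;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  deepmd::DPGetDeviceCount(gpu_num);  // how many GPUs are visible
  if (gpu_num > 0) {
    // distribute ranks round-robin over the visible devices;
    // DPErrcheck aborts with file/line info if the runtime call fails
    DPErrcheck(deepmd::DPSetDevice(gpu_rank % gpu_num));
  }
#endif
}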
33 changes: 20 additions & 13 deletions source/lib/include/gpu_cuda.h
@@ -5,8 +5,9 @@
#include <cuda_runtime.h>

#define GPU_MAX_NBOR_SIZE 4096
#define cudaErrcheck(res) {cudaAssert((res), __FILE__, __LINE__);}
inline void cudaAssert(cudaError_t code, const char *file, int line, bool abort=true) {
#define DPErrcheck(res) {DPAssert((res), __FILE__, __LINE__);}
inline void DPAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess) {
fprintf(stderr,"cuda assert: %s %s %d\n", cudaGetErrorString(code), file, line);
if (code == 2) {
@@ -27,7 +28,8 @@ inline void cudaAssert(cudaError_t code, const char *file, int line, bool abort=
}

#define nborErrcheck(res) {nborAssert((res), __FILE__, __LINE__);}
inline void nborAssert(cudaError_t code, const char *file, int line, bool abort=true) {
inline void nborAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess) {
fprintf(stderr,"cuda assert: %s %s %d\n", "DeePMD-kit:\tillegal nbor list sorting", file, line);
if (code == 2) {
@@ -65,12 +67,17 @@ static __inline__ __device__ double atomicAdd(
#endif

namespace deepmd {

inline void DPGetDeviceCount(int &gpu_num) { cudaGetDeviceCount(&gpu_num) ;}

inline cudaError_t DPSetDevice(int rank) { return cudaSetDevice(rank); }

template <typename FPTYPE>
void memcpy_host_to_device(
FPTYPE * device,
const std::vector<FPTYPE> &host)
{
cudaErrcheck(cudaMemcpy(device, &host[0], sizeof(FPTYPE) * host.size(), cudaMemcpyHostToDevice));
DPErrcheck(cudaMemcpy(device, &host[0], sizeof(FPTYPE) * host.size(), cudaMemcpyHostToDevice));
}

template <typename FPTYPE>
@@ -79,15 +86,15 @@ void memcpy_host_to_device(
const FPTYPE * host,
const int size)
{
cudaErrcheck(cudaMemcpy(device, host, sizeof(FPTYPE) * size, cudaMemcpyHostToDevice));
DPErrcheck(cudaMemcpy(device, host, sizeof(FPTYPE) * size, cudaMemcpyHostToDevice));
}

template <typename FPTYPE>
void memcpy_device_to_host(
const FPTYPE * device,
std::vector<FPTYPE> &host)
{
cudaErrcheck(cudaMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(), cudaMemcpyDeviceToHost));
DPErrcheck(cudaMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(), cudaMemcpyDeviceToHost));
}

template <typename FPTYPE>
@@ -96,31 +103,31 @@ void memcpy_device_to_host(
FPTYPE * host,
const int size)
{
cudaErrcheck(cudaMemcpy(host, device, sizeof(FPTYPE) * size, cudaMemcpyDeviceToHost));
DPErrcheck(cudaMemcpy(host, device, sizeof(FPTYPE) * size, cudaMemcpyDeviceToHost));
}

template <typename FPTYPE>
void malloc_device_memory(
FPTYPE * &device,
const std::vector<FPTYPE> &host)
{
cudaErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
DPErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
}

template <typename FPTYPE>
void malloc_device_memory(
FPTYPE * &device,
const int size)
{
cudaErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * size));
DPErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * size));
}

template <typename FPTYPE>
void malloc_device_memory_sync(
FPTYPE * &device,
const std::vector<FPTYPE> &host)
{
cudaErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
DPErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
memcpy_host_to_device(device, host);
}

@@ -130,7 +137,7 @@ void malloc_device_memory_sync(
const FPTYPE * host,
const int size)
{
cudaErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * size));
DPErrcheck(cudaMalloc((void **)&device, sizeof(FPTYPE) * size));
memcpy_host_to_device(device, host, size);
}

@@ -139,7 +146,7 @@ void delete_device_memory(
FPTYPE * &device)
{
if (device != NULL) {
cudaErrcheck(cudaFree(device));
DPErrcheck(cudaFree(device));
}
}

@@ -149,6 +156,6 @@ void memset_device_memory(
const FPTYPE var,
const int size)
{
cudaErrcheck(cudaMemset(device, var, sizeof(FPTYPE) * size));
DPErrcheck(cudaMemset(device, var, sizeof(FPTYPE) * size));
}
} // end of namespace deepmd
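
The renamed helpers keep their signatures, so callers change only in which macro fires on failure. Below is a hypothetical round trip through the deepmd:: memory helpers; the helper names and signatures come from this diff, the caller itself is assumed.

#include <vector>
#include "gpu_cuda.h"

void roundtrip_sketch() {
  std::vector<double> host = {1.0, 2.0, 3.0};
  double *device = NULL;
  // allocate device memory and copy host -> device in one call
  deepmd::malloc_device_memory_sync(device, host);
  // ... launch kernels that update `device` ...
  deepmd::memcpy_device_to_host(device, host);  // copy results back
  deepmd::delete_device_memory(device);         // checked cudaFree
}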
29 changes: 17 additions & 12 deletions source/lib/include/gpu_rocm.h
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@

#define GPU_MAX_NBOR_SIZE 4096

#define hipErrcheck(res) { hipAssert((res), __FILE__, __LINE__); }
inline void hipAssert(hipError_t code, const char *file, int line, bool abort=true) {
#define DPErrcheck(res) { DPAssert((res), __FILE__, __LINE__); }
inline void DPAssert(hipError_t code, const char *file, int line, bool abort=true) {
if (code != hipSuccess) {
fprintf(stderr,"hip assert: %s %s %d\n", hipGetErrorString(code), file, line);
if (abort) exit(code);
@@ -24,13 +24,18 @@ inline void nborAssert(hipError_t code, const char *file, int line, bool abort=t
}
}


namespace deepmd {
inline void DPGetDeviceCount(int &gpu_num) { hipGetDeviceCount(&gpu_num) ;}

inline hipError_t DPSetDevice(int rank) { return hipSetDevice(rank); }

template <typename FPTYPE>
void memcpy_host_to_device(
FPTYPE * device,
std::vector<FPTYPE> &host)
{
hipErrcheck(hipMemcpy(device, &host[0], sizeof(FPTYPE) * host.size(), hipMemcpyHostToDevice));
DPErrcheck(hipMemcpy(device, &host[0], sizeof(FPTYPE) * host.size(), hipMemcpyHostToDevice));
}

template <typename FPTYPE>
@@ -39,47 +44,47 @@ void memcpy_host_to_device(
const FPTYPE * host,
const int size)
{
hipErrcheck(hipMemcpy(device, host, sizeof(FPTYPE) * size, hipMemcpyHostToDevice));
DPErrcheck(hipMemcpy(device, host, sizeof(FPTYPE) * size, hipMemcpyHostToDevice));
}

template <typename FPTYPE>
void memcpy_device_to_host(
FPTYPE * device,
std::vector<FPTYPE> &host)
{
hipErrcheck(hipMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(), hipMemcpyDeviceToHost));
DPErrcheck(hipMemcpy(&host[0], device, sizeof(FPTYPE) * host.size(), hipMemcpyDeviceToHost));
}
template <typename FPTYPE>
void memcpy_device_to_host(
const FPTYPE * device,
FPTYPE * host,
const int size)
{
hipErrcheck(hipMemcpy(host, device, sizeof(FPTYPE) * size, hipMemcpyDeviceToHost));
DPErrcheck(hipMemcpy(host, device, sizeof(FPTYPE) * size, hipMemcpyDeviceToHost));
}

template <typename FPTYPE>
void malloc_device_memory(
FPTYPE * &device,
std::vector<FPTYPE> &host)
{
hipErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
}

template <typename FPTYPE>
void malloc_device_memory(
FPTYPE * &device,
const int size)
{
hipErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * size));
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * size));
}

template <typename FPTYPE>
void malloc_device_memory_sync(
FPTYPE * &device,
std::vector<FPTYPE> &host)
{
hipErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * host.size()));
memcpy_host_to_device(device, host);
}
template <typename FPTYPE>
@@ -88,7 +93,7 @@ void malloc_device_memory_sync(
const FPTYPE * host,
const int size)
{
hipErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * size));
DPErrcheck(hipMalloc((void **)&device, sizeof(FPTYPE) * size));
memcpy_host_to_device(device, host, size);
}

@@ -97,7 +102,7 @@ void delete_device_memory(
FPTYPE * &device)
{
if (device != NULL) {
hipErrcheck(hipFree(device));
DPErrcheck(hipFree(device));
}
}

@@ -107,7 +112,7 @@ void memset_device_memory(
const FPTYPE var,
const int size)
{
hipErrcheck(hipMemset(device,var,sizeof(FPTYPE)*size));
DPErrcheck(hipMemset(device,var,sizeof(FPTYPE)*size));
}
}

(diff truncated)
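
Because gpu_cuda.h and gpu_rocm.h now export identical names, a call site can be written once and compiled against either backend. A small illustration of that dispatch (the #if include selection here is illustrative; only the DP* wrappers come from this commit):

#if GOOGLE_CUDA
#include "gpu_cuda.h"   // DPErrcheck/DPAssert wrap cudaError_t
#elif TENSORFLOW_USE_ROCM
#include "gpu_rocm.h"   // the same names wrap hipError_t
#endif

int count_devices() {
  int gpu_num = -1;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  deepmd::DPGetDeviceCount(gpu_num);  // identical call on both backends
#endif
  return gpu_num;  // stays -1 when built without GPU support
}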
(diffs for the remaining 27 changed files not shown)
