diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 149f70f74f..9d1ca2415a 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -54,6 +54,10 @@ list(APPEND device_srcs source_pw/module_pwdft/kernels/vnl_op.cpp source_base/kernels/math_ylm_op.cpp source_hamilt/module_xc/kernels/xc_functional_op.cpp + source_pw/module_pwdft/kernels/cal_density_real_op.cpp + source_pw/module_pwdft/kernels/mul_potential_op.cpp + source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.cpp + source_pw/module_pwdft/kernels/exx_cal_energy_op.cpp ) if(USE_CUDA) @@ -80,6 +84,10 @@ if(USE_CUDA) source_base/kernels/cuda/math_kernel_op.cu source_base/kernels/cuda/math_kernel_op_vec.cu source_hamilt/module_xc/kernels/cuda/xc_functional_op.cu + source_pw/module_pwdft/kernels/cuda/cal_density_real_op.cu + source_pw/module_pwdft/kernels/cuda/mul_potential_op.cu + source_pw/module_pwdft/kernels/cuda/vec_mul_vec_complex.cu + source_pw/module_pwdft/kernels/cuda/exx_cal_energy_op.cu ) endif() diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 2a3d6fa918..00abcf82be 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -342,6 +342,10 @@ OBJS_HAMILT=hamilt_pw.o\ velocity_pw.o\ radial_proj.o\ exx_helper.o\ + vec_mul_vec_complex_op.o\ + exx_cal_energy_op.o\ + cal_density_real_op.o\ + mul_potential_op.o\ OBJS_HAMILT_OF=kedf_tf.o\ kedf_vw.o\ diff --git a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu index 32188c3e8b..54176dae51 100644 --- a/source/source_base/module_container/ATen/kernels/cuda/lapack.cu +++ b/source/source_base/module_container/ATen/kernels/cuda/lapack.cu @@ -70,8 +70,6 @@ struct lapack_trtri { { // TODO: trtri is not implemented in this method yet // Cause the trtri in cuSolver is not stable for ABACUS! - // But why?! trtri and potri are different routines for different job! - // How can BPCG work without using a proper routine? 
cuSolverConnector::trtri(cusolver_handle, uplo, diag, dim, Mat, lda); // cuSolverConnector::potri(cusolver_handle, uplo, diag, dim, Mat, lda); } diff --git a/source/source_base/module_container/base/third_party/cusolver.h b/source/source_base/module_container/base/third_party/cusolver.h index defd60a121..111b321e26 100644 --- a/source/source_base/module_container/base/third_party/cusolver.h +++ b/source/source_base/module_container/base/third_party/cusolver.h @@ -87,45 +87,57 @@ static inline void potrf (cusolverDnHandle_t& cusolver_handle, const char& uplo, const int& n, float * A, const int& lda) { int lwork; + int *info = nullptr; + cudaErrcheck(cudaMalloc((void**)&info, 1 * sizeof(int))); - cusolverErrcheck(cusolverDnSpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, A, n, &lwork)); + cusolverErrcheck(cusolverDnSpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, A, lda, &lwork)); /* honour the caller's leading dimension, as the complex overloads do */ float* work; cudaErrcheck(cudaMalloc((void**)&work, lwork * sizeof(float))); // Perform Cholesky decomposition - cusolverErrcheck(cusolverDnSpotrf(cusolver_handle, cublas_fill_mode(uplo), n, A, n, work, lwork, nullptr)); + cusolverErrcheck(cusolverDnSpotrf(cusolver_handle, cublas_fill_mode(uplo), n, A, lda, work, lwork, info)); cudaErrcheck(cudaFree(work)); + cudaErrcheck(cudaFree(info)); /* NOTE(review): info is handed to cuSOLVER but never copied back or inspected before it is freed */ } static inline void potrf (cusolverDnHandle_t& cusolver_handle, const char& uplo, const int& n, double * A, const int& lda) { int lwork; + int *info = nullptr; + cudaErrcheck(cudaMalloc((void**)&info, 1 * sizeof(int))); - cusolverErrcheck(cusolverDnDpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, A, n, &lwork)); + cusolverErrcheck(cusolverDnDpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, A, lda, &lwork)); /* honour the caller's leading dimension, as the complex overloads do */ double* work; cudaErrcheck(cudaMalloc((void**)&work, lwork * sizeof(double))); // Perform Cholesky decomposition - cusolverErrcheck(cusolverDnDpotrf(cusolver_handle, cublas_fill_mode(uplo), n, A, n, work, lwork, nullptr)); + cusolverErrcheck(cusolverDnDpotrf(cusolver_handle, cublas_fill_mode(uplo), n, A, lda, work, lwork, info)); cudaErrcheck(cudaFree(work)); + cudaErrcheck(cudaFree(info)); /* NOTE(review): info is handed to cuSOLVER but never copied back or inspected before it is freed */ } static inline void potrf (cusolverDnHandle_t& cusolver_handle, const char& uplo, 
const int& n, std::complex * A, const int& lda) { int lwork; - cusolverErrcheck(cusolverDnCpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), n, &lwork)); + int *info = nullptr; + cudaErrcheck(cudaMalloc((void**)&info, 1 * sizeof(int))); + cusolverErrcheck(cusolverDnCpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), lda, &lwork)); cuComplex* work; cudaErrcheck(cudaMalloc((void**)&work, lwork * sizeof(cuComplex))); // Perform Cholesky decomposition - cusolverErrcheck(cusolverDnCpotrf(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), n, work, lwork, nullptr)); + cusolverErrcheck(cusolverDnCpotrf(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), lda, work, lwork, info)); cudaErrcheck(cudaFree(work)); + cudaErrcheck(cudaFree(info)); } static inline void potrf (cusolverDnHandle_t& cusolver_handle, const char& uplo, const int& n, std::complex * A, const int& lda) { int lwork; - cusolverErrcheck(cusolverDnZpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), n, &lwork)); + int *info = nullptr; + cudaErrcheck(cudaMalloc((void**)&info, 1 * sizeof(int))); + cusolverErrcheck(cusolverDnZpotrf_bufferSize(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), lda, &lwork)); cuDoubleComplex* work; cudaErrcheck(cudaMalloc((void**)&work, lwork * sizeof(cuDoubleComplex))); // Perform Cholesky decomposition - cusolverErrcheck(cusolverDnZpotrf(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), n, work, lwork, nullptr)); + cusolverErrcheck(cusolverDnZpotrf(cusolver_handle, cublas_fill_mode(uplo), n, reinterpret_cast(A), lda, work, lwork, info)); cudaErrcheck(cudaFree(work)); + cudaErrcheck(cudaFree(info)); } diff --git a/source/source_basis/module_pw/pw_basis.h b/source/source_basis/module_pw/pw_basis.h index b4023d991a..1578af0d83 100644 --- a/source/source_basis/module_pw/pw_basis.h +++ b/source/source_basis/module_pw/pw_basis.h @@ -432,6 
+432,9 @@ class PW_Basis void set_device(std::string device_); void set_precision(std::string precision_); + std::string get_device() const { return device; } + std::string get_precision() const { return precision; } + protected: std::string device = "cpu"; ///< cpu or gpu diff --git a/source/source_esolver/esolver_ks_pw.cpp b/source/source_esolver/esolver_ks_pw.cpp index 2d3e3f0164..6ed705eb1c 100644 --- a/source/source_esolver/esolver_ks_pw.cpp +++ b/source/source_esolver/esolver_ks_pw.cpp @@ -619,6 +619,7 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int { auto start = std::chrono::high_resolution_clock::now(); exx_helper.set_firstiter(false); + exx_helper.op_exx->first_iter = false; exx_helper.set_psi(this->kspw_psi); conv_esolver = exx_helper.exx_after_converge(iter); diff --git a/source/source_io/input_conv.cpp b/source/source_io/input_conv.cpp index 37cc906c6d..952ca873d1 100644 --- a/source/source_io/input_conv.cpp +++ b/source/source_io/input_conv.cpp @@ -488,8 +488,8 @@ void Input_Conv::Convert() { if (ModuleSymmetry::Symmetry::symm_flag != -1) { - ModuleBase::WARNING("Input_Conv", "EXX PW works only with symmetry=-1"); - ModuleSymmetry::Symmetry::symm_flag = -1; + ModuleBase::WARNING_QUIT("Input_Conv", "EXX PW works only with symmetry=-1"); + // ModuleSymmetry::Symmetry::symm_flag = -1; } if (PARAM.inp.nspin != 1 && PARAM.inp.nspin != 2) @@ -497,10 +497,6 @@ void Input_Conv::Convert() ModuleBase::WARNING_QUIT("Input_Conv", "EXX PW works only with nspin=1 and 2"); } - if (PARAM.inp.device != "cpu") - { - ModuleBase::WARNING_QUIT("Input_Conv", "EXX PW works only with device=cpu"); - } } //---------------------------------------------------------- diff --git a/source/source_pw/module_pwdft/CMakeLists.txt b/source/source_pw/module_pwdft/CMakeLists.txt index e458589879..03e808f6e6 100644 --- a/source/source_pw/module_pwdft/CMakeLists.txt +++ b/source/source_pw/module_pwdft/CMakeLists.txt @@ -26,6 +26,7 @@ list(APPEND objects 
stress_func_nl.cpp stress_func_us.cpp stress_func_onsite.cpp + stress_func_exx.cpp stress_pw.cpp VL_in_pw.cpp VNL_in_pw.cpp @@ -47,7 +48,6 @@ add_library( module_pwdft OBJECT ${objects} - stress_func_exx.cpp ) if(ENABLE_COVERAGE) diff --git a/source/source_pw/module_pwdft/kernels/cal_density_real_op.cpp b/source/source_pw/module_pwdft/kernels/cal_density_real_op.cpp new file mode 100644 index 0000000000..b804d23f2a --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/cal_density_real_op.cpp @@ -0,0 +1,25 @@ +#include "source_pw/module_pwdft/kernels/cal_density_real_op.h" +#include "source_psi/psi.h" +namespace hamilt +{ +template +struct cal_density_real_op +{ + void operator()(const T *in1, const T *in2, T *out, double omega, int nrxx) + { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (int ir = 0; ir < nrxx; ir++) + { + // assert(is_finite(psi_nk_real[ir])); + // assert(is_finite(psi_mq_real[ir])); + out[ir] = in1[ir] * std::conj(in2[ir]) / static_cast(omega); // Phase e^(i(q-k)r) + } + } + +}; + +template struct cal_density_real_op, base_device::DEVICE_CPU>; +template struct cal_density_real_op, base_device::DEVICE_CPU>; +} \ No newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/cal_density_real_op.h b/source/source_pw/module_pwdft/kernels/cal_density_real_op.h new file mode 100644 index 0000000000..68fa40b90c --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/cal_density_real_op.h @@ -0,0 +1,14 @@ +#include "source_base/macros.h" + +#ifndef CAL_DENSITY_REAL_OP_H +#define CAL_DENSITY_REAL_OP_H +namespace hamilt +{ +template +struct cal_density_real_op +{ + using Real = typename GetTypeReal::type; + void operator()(const T *psi1, const T* psi2, T *out, double omega, int nrxx); +}; +} +#endif //CAL_DENSITY_REAL_OP_H diff --git a/source/source_pw/module_pwdft/kernels/cuda/cal_density_real_op.cu b/source/source_pw/module_pwdft/kernels/cuda/cal_density_real_op.cu new file mode 100644 index 
0000000000..629134d046 --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/cuda/cal_density_real_op.cu @@ -0,0 +1,48 @@ +#include "source_pw/module_pwdft/kernels/cal_density_real_op.h" +#include "source_psi/psi.h" + +#include + +namespace hamilt +{ +template +__global__ void cal_density_real_kernel( + const thrust::complex *in1, + const thrust::complex *in2, + thrust::complex *out, + const FPTYPE omega, + int nrxx) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < nrxx) + { + out[idx] = in1[idx] * thrust::conj(in2[idx]) / static_cast>(omega); + } +} + +template +struct cal_density_real_op, base_device::DEVICE_GPU> +{ + using T = std::complex; + void operator()(const T *psi1, const T *psi2, T *out, double omega, int nrxx) + { + int threads_per_block = 256; + int num_blocks = (nrxx + threads_per_block - 1) / threads_per_block; + + cal_density_real_kernel<<>>( + reinterpret_cast *>(psi1), + reinterpret_cast *>(psi2), + reinterpret_cast *>(out), + static_cast(omega), nrxx); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + throw std::runtime_error("CUDA error in cal_density_real_kernel: " + std::string(cudaGetErrorString(err))); + } + } +}; + +template struct cal_density_real_op, base_device::DEVICE_GPU>; +template struct cal_density_real_op, base_device::DEVICE_GPU>; +} // namespace hamilt \ No newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/cuda/exx_cal_energy_op.cu b/source/source_pw/module_pwdft/kernels/cuda/exx_cal_energy_op.cu new file mode 100644 index 0000000000..6fe5517d78 --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/cuda/exx_cal_energy_op.cu @@ -0,0 +1,84 @@ +#include "source_pw/module_pwdft/kernels/exx_cal_energy_op.h" +#include "source_psi/psi.h" + +#include + +namespace hamilt +{ + +// #ifdef _OPENMP +// #pragma omp parallel for reduction(+:Eexx_ik_real) +// #endif +// for (int ig = 0; ig < rhopw_dev->npw; ig++) +// { +// int nks = wfcpw->nks; +// int npw = rhopw_dev->npw; 
+// int nk = nks / nk_fac; +// Real Fac = pot[ik * nks * npw + iq * npw + ig]; + +// Eexx_ik_real += Fac * (density_recip[ig] * std::conj(density_recip[ig])).real() +// * wg_iqb_real / nqs * wg_ikb_real / kv->wk[ik]; +// } + +template +__global__ void cal_vec_norm_kernel( + const thrust::complex *den, + const FPTYPE *pot, + FPTYPE *result, + int npw) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < npw) + { + // atomicAdd(result, (den[idx].real() * den[idx].real() + den[idx].imag() * den[idx].imag()) * pot[idx]); + FPTYPE tmp =(den[idx] * thrust::conj(den[idx])).real() * pot[idx]; + atomicAdd(result, tmp); + } + __syncthreads(); +} + +template +struct exx_cal_energy_op, base_device::DEVICE_GPU> +{ + using T = std::complex; + FPTYPE operator()(const T *den, const FPTYPE *pot, FPTYPE scalar, int npw) + { + // Returns scalar * sum over ig < npw of |den[ig]|^2 * pot[ig]. + // den and pot are forwarded to the kernel unchanged, so they must be device-accessible pointers. + // The sum is accumulated on the device with one atomicAdd per element and the + // single accumulated value is copied back to the host with cudaMemcpy. + // npw <= 0 skips the launch (a zero-block grid is not a valid launch) and yields scalar * 0. + // Throws std::runtime_error when any CUDA call fails; the temporary device + // accumulator is released before the exception is raised. + // NOTE(review): atomicAdd on double requires compute capability 6.0+ -- confirm the build targets. + FPTYPE result = 0.0; + + int threads_per_block = 256; + int num_blocks = (npw + threads_per_block - 1) / threads_per_block; + + // One-element device accumulator; every CUDA call below is checked and the buffer is freed on all paths. + FPTYPE *d_result = nullptr; + cudaError_t err = cudaMalloc(&d_result, sizeof(FPTYPE)); + if (err == cudaSuccess) { err = cudaMemset(d_result, 0, sizeof(FPTYPE)); } + if (err == cudaSuccess && npw > 0) { + cal_vec_norm_kernel<<>>( + reinterpret_cast *>(den), + pot, + d_result, + npw); + err = cudaGetLastError(); + } + if (err == cudaSuccess) { err = cudaMemcpy(&result, d_result, sizeof(FPTYPE), cudaMemcpyDeviceToHost); } + cudaFree(d_result); /* cudaFree(nullptr) is a no-op, so this is safe even when the allocation failed */ + if (err != cudaSuccess) + { + throw std::runtime_error("CUDA error in cal_vec_norm_kernel: " + std::string(cudaGetErrorString(err))); + } + + return scalar * result; + } +}; + +template struct exx_cal_energy_op, base_device::DEVICE_GPU>; +template struct exx_cal_energy_op, base_device::DEVICE_GPU>; +} // namespace hamilt \ No 
newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/cuda/mul_potential_op.cu b/source/source_pw/module_pwdft/kernels/cuda/mul_potential_op.cu new file mode 100644 index 0000000000..ad148deb23 --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/cuda/mul_potential_op.cu @@ -0,0 +1,53 @@ +#include "source_pw/module_pwdft/kernels/mul_potential_op.h" +#include "source_io/module_parameter/parameter.h" +#include "source_psi/psi.h" + +#include + +namespace hamilt { +template +__global__ void mul_potential_kernel( + const FPTYPE *pot_shifted, + thrust::complex *density_recip, + int npw) +{ + int ig = blockIdx.x * blockDim.x + threadIdx.x; + if (ig < npw) + { + density_recip[ig] *= pot_shifted[ig]; + } +} + +template +struct mul_potential_op, base_device::DEVICE_GPU> +{ + using T = std::complex; + void operator()(const FPTYPE *pot, T *density_recip, int npw, int nks, int ik, int iq) + { +// #ifdef _OPENMP +// #pragma omp parallel for schedule(static) +// #endif +// for (int ig = 0; ig < npw; ig++) +// { +// int ig_kq = ik * nks * npw + iq * npw + ig; +// density_recip[ig] *= pot[ig_kq]; +// +// } + int threads_per_block = 256; + int num_blocks = (npw + threads_per_block - 1) / threads_per_block; + + mul_potential_kernel<<>>( + pot + ik * nks * npw + iq * npw, + reinterpret_cast*>(density_recip), + npw); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + throw std::runtime_error("CUDA error in mul_potential_kernel: " + std::string(cudaGetErrorString(err))); + } + } +}; +template struct mul_potential_op, base_device::DEVICE_GPU>; +template struct mul_potential_op, base_device::DEVICE_GPU>; +} // hamilt \ No newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/cuda/vec_mul_vec_complex.cu b/source/source_pw/module_pwdft/kernels/cuda/vec_mul_vec_complex.cu new file mode 100644 index 0000000000..262591a9de --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/cuda/vec_mul_vec_complex.cu @@ -0,0 +1,47 @@ 
+#include "source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.h" +#include "source_io/module_parameter/parameter.h" +#include "source_psi/psi.h" + +#include + +namespace hamilt { +template +__global__ void vec_mul_vec_complex_kernel( + const thrust::complex *vec1, + const thrust::complex *vec2, + thrust::complex *result, + int size) +{ + int ig = blockIdx.x * blockDim.x + threadIdx.x; + if (ig < size) + { + result[ig] = vec1[ig] * vec2[ig]; + } +} + +template +struct vec_mul_vec_complex_op, base_device::DEVICE_GPU> +{ + using T = std::complex; + void operator()(const T *a, const T *b, T* out, int size) + { + int threads_per_block = 256; + int num_blocks = (size + threads_per_block - 1) / threads_per_block; + + vec_mul_vec_complex_kernel<<>>( + reinterpret_cast *>(a), + reinterpret_cast *>(b), + reinterpret_cast *>(out), + size); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + throw std::runtime_error("CUDA error in vec_mul_vec_kernel: " + std::string(cudaGetErrorString(err))); + } + } + +}; +template struct vec_mul_vec_complex_op, base_device::DEVICE_GPU>; +template struct vec_mul_vec_complex_op, base_device::DEVICE_GPU>; +} // hamilt \ No newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/exx_cal_energy_op.cpp b/source/source_pw/module_pwdft/kernels/exx_cal_energy_op.cpp new file mode 100644 index 0000000000..83ad2026e5 --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/exx_cal_energy_op.cpp @@ -0,0 +1,43 @@ +#include "source_pw/module_pwdft/kernels/exx_cal_energy_op.h" +#include "source_psi/psi.h" + +namespace hamilt +{ + +// #ifdef _OPENMP +// #pragma omp parallel for reduction(+:Eexx_ik_real) +// #endif +// for (int ig = 0; ig < rhopw_dev->npw; ig++) +// { +// int nks = wfcpw->nks; +// int npw = rhopw_dev->npw; +// int nk = nks / nk_fac; +// Real Fac = pot[ik * nks * npw + iq * npw + ig]; + +// Eexx_ik_real += Fac * (density_recip[ig] * std::conj(density_recip[ig])).real() +// * wg_iqb_real / nqs * 
wg_ikb_real / kv->wk[ik]; +// } + +template +struct exx_cal_energy_op, base_device::DEVICE_CPU> +{ + using T = std::complex; + FPTYPE operator()(const T *den, const FPTYPE *pot, FPTYPE scalar, int npw) + { + FPTYPE energy = 0.0; + #ifdef _OPENMP + #pragma omp parallel for reduction(+:energy) + #endif + for (int ig = 0; ig < npw; ++ig) + { + // Calculate the energy contribution from each reciprocal lattice vector + energy += (den[ig] * std::conj(den[ig])).real() * pot[ig]; + } + // Scale the energy by the scalar factor + return scalar * energy; + } +}; + +template struct exx_cal_energy_op, base_device::DEVICE_CPU>; +template struct exx_cal_energy_op, base_device::DEVICE_CPU>; +} // namespace hamilt diff --git a/source/source_pw/module_pwdft/kernels/exx_cal_energy_op.h b/source/source_pw/module_pwdft/kernels/exx_cal_energy_op.h new file mode 100644 index 0000000000..50e694d97b --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/exx_cal_energy_op.h @@ -0,0 +1,14 @@ +#ifndef CAL_VEC_NORM_OP_H +#define CAL_VEC_NORM_OP_H +#include "source_base/macros.h" +namespace hamilt{ +template +struct exx_cal_energy_op +{ + + using FPTYPE = typename GetTypeReal::type; + FPTYPE operator()(const T *den, const FPTYPE *pot, FPTYPE scala, int npw); +}; + +} // namespace hamilt +#endif //CAL_VEC_NORM_OP_H diff --git a/source/source_pw/module_pwdft/kernels/mul_potential_op.cpp b/source/source_pw/module_pwdft/kernels/mul_potential_op.cpp new file mode 100644 index 0000000000..85647e065d --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/mul_potential_op.cpp @@ -0,0 +1,26 @@ +#include "source_pw/module_pwdft/kernels/mul_potential_op.h" +#include "source_io/module_parameter/parameter.h" +#include "source_psi/psi.h" +#include "source_base/macros.h" +namespace hamilt { +template +struct mul_potential_op, base_device::DEVICE_CPU> +{ + using T = std::complex; + void operator()(const FPTYPE *pot, T *density_recip, int npw, int nks, int ik, int iq) + { + #ifdef _OPENMP + #pragma omp 
parallel for schedule(static) + #endif + for (int ig = 0; ig < npw; ig++) + { + int ig_kq = ik * nks * npw + iq * npw + ig; + density_recip[ig] *= pot[ig_kq]; + + } + } +}; + +template struct mul_potential_op, base_device::DEVICE_CPU>; +template struct mul_potential_op, base_device::DEVICE_CPU>; +} // hamilt \ No newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/mul_potential_op.h b/source/source_pw/module_pwdft/kernels/mul_potential_op.h new file mode 100644 index 0000000000..b1cd647b19 --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/mul_potential_op.h @@ -0,0 +1,30 @@ +#ifndef MUL_POTENTIAL_OP_H +#define MUL_POTENTIAL_OP_H +#include "source_psi/psi.h" +#include "source_base/macros.h" + +namespace hamilt { + +template +struct mul_potential_op +{ +// int npw = rhopw_dev->npw; +// int nks = wfcpw->nks; +// int nk_fac = PARAM.inp.nspin == 2 ? 2 : 1; +// int nk = nks / nk_fac; +// +// #ifdef _OPENMP +// #pragma omp parallel for schedule(static) +// #endif +// for (int ig = 0; ig < npw; ig++) +// { +// int ig_kq = ik * nks * npw + iq * npw + ig; +// density_recip[ig] *= pot[ig_kq]; +// } + using FPTYPE = typename GetTypeReal::type; + void operator()(const FPTYPE *pot, T *density_recip, int npw, int nks, int ik, int iq); +}; + +} // namespace hamilt + +#endif // MUL_POTENTIAL_OP_H diff --git a/source/source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.cpp b/source/source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.cpp new file mode 100644 index 0000000000..5798143740 --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.cpp @@ -0,0 +1,23 @@ +#include "source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.h" +#include "source_io/module_parameter/parameter.h" +#include "source_psi/psi.h" +namespace hamilt{ +template +struct vec_mul_vec_complex_op, base_device::DEVICE_CPU> +{ + using T = std::complex; + void operator()(const T *vec1, const T *vec2, T *out, int n) + { + #ifdef _OPENMP + #pragma omp parallel 
for schedule(static) + #endif + for (int i = 0; i < n; i++) + { + out[i] = vec1[i] * vec2[i]; + } + } + +}; +template struct vec_mul_vec_complex_op, base_device::DEVICE_CPU>; +template struct vec_mul_vec_complex_op, base_device::DEVICE_CPU>; +} // hamilt \ No newline at end of file diff --git a/source/source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.h b/source/source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.h new file mode 100644 index 0000000000..0618f37c7b --- /dev/null +++ b/source/source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.h @@ -0,0 +1,16 @@ +// +// Created by rhx on 25-6-26. +// + +#ifndef VEC_MUL_VEC_OP_H +#define VEC_MUL_VEC_OP_H +namespace hamilt { + +template +struct vec_mul_vec_complex_op +{ + // Multiply a vector with a complex vector + void operator()(const T *vec1, const T *vec2, T *out, int n); +}; +} // namespace hamilt +#endif //VEC_MUL_VEC_OP_H diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp index 7e294754b8..68bfca3158 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.cpp @@ -1,3 +1,5 @@ +#include "op_exx_pw.h" + #include "source_base/constants.h" #include "source_base/global_variable.h" #include "source_base/parallel_common.h" @@ -8,44 +10,19 @@ #include "source_cell/klist.h" #include "source_hamilt/operator.h" #include "source_psi/psi.h" +#include "source_pw/module_pwdft/global.h" +#include "source_pw/module_pwdft/kernels/cal_density_real_op.h" +#include "source_pw/module_pwdft/kernels/exx_cal_energy_op.h" +#include "source_pw/module_pwdft/kernels/mul_potential_op.h" +#include "source_pw/module_pwdft/kernels/vec_mul_vec_complex_op.h" #include #include #include -#include #include -extern "C" -{ - void ztrtri_(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info); - void ctrtri_(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info); -} - 
-//extern "C" void zpotrf_(char* uplo, const int* n, std::complex* A, const int* lda, int* info); -//extern "C" void cpotrf_(char* uplo, const int* n, std::complex* A, const int* lda, int* info); - -#include "op_exx_pw.h" -#include "source_pw/module_pwdft/global.h" - namespace hamilt { -template -struct trtri_op -{ - void operator()(char *uplo, char *diag, int *n, T *a, int *lda, int *info) - { - std::cout << "trtri_op not implemented" << std::endl; - } -}; - -template -struct potrf_op -{ - void operator()(char *uplo, int *n, T *a, int *lda, int *info) - { - std::cout << "potrf_op not implemented" << std::endl; - } -}; template OperatorEXXPW::OperatorEXXPW(const int* isk_in, @@ -93,9 +70,20 @@ OperatorEXXPW::OperatorEXXPW(const int* isk_in, tpiba = ucell->tpiba; Real tpiba2 = tpiba * tpiba; - // calculate the exx_divergence -} + rhopw_dev = new ModulePW::PW_Basis(wfcpw->get_device(), rhopw->get_precision()); + rhopw_dev->fft_bundle.setfft(wfcpw->get_device(), rhopw->get_precision()); +#ifdef __MPI + rhopw_dev->initmpi(rhopw->poolnproc, rhopw->poolrank, rhopw->pool_world); +#endif + // here we can actually use different ecut to init the grids + rhopw_dev->initgrids(rhopw->lat0, rhopw->latvec, rhopw->gridecut_lat * rhopw->tpiba2); + rhopw_dev->initgrids(rhopw->lat0, rhopw->latvec, rhopw->nx, rhopw->ny, rhopw->nz); + rhopw_dev->initparameters(rhopw->gamma_only, rhopw->ggecut * rhopw->tpiba2, rhopw->distribution_type, rhopw->xprime); + rhopw_dev->setuptransform(); + rhopw_dev->collect_local_pw(); + +} // end of constructor template OperatorEXXPW::~OperatorEXXPW() @@ -118,7 +106,7 @@ OperatorEXXPW::~OperatorEXXPW() delmem_complex_op()(Xi_ace); } Xi_ace_k.clear(); - + delete rhopw_dev; } template @@ -191,9 +179,9 @@ void OperatorEXXPW::act_op(const int nbands, ModuleBase::timer::tick("OperatorEXXPW", "act_op"); setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); - setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); - setmem_complex_op()(density_real, 0, rhopw->nrxx); - 
setmem_complex_op()(density_recip, 0, rhopw->npw); + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); // setmem_complex_op()(psi_all_real, 0, wfcpw->nrxx * GlobalV::NBANDS); // std::map, bool> has_real; setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); @@ -222,79 +210,57 @@ void OperatorEXXPW::act_op(const int nbands, continue; } - // if (has_real.find({iq, m_iband}) == has_real.end()) - // { - const T* psi_mq = get_pw(m_iband, iq); - wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); - // syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx); - // has_real[{iq, m_iband}] = true; - // } - // else - // { - // // const T* psi_mq = get_pw(m_iband, iq); - // // wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); - // syncmem_complex_op()(this->ctx, this->ctx, psi_mq_real, psi_all_real + m_iband * wfcpw->nrxx, wfcpw->nrxx); - // } - + const T* psi_mq = get_pw(m_iband, iq); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); + // direct multiplication in real space, \psi_nk(r) * \psi_mq(r) - #ifdef _OPENMP - #pragma omp parallel for schedule(static) - #endif - for (int ir = 0; ir < wfcpw->nrxx; ir++) - { - // assert(is_finite(psi_nk_real[ir])); - // assert(is_finite(psi_mq_real[ir])); - Real ucell_omega = ucell->omega; - density_real[ir] = psi_nk_real[ir] * std::conj(psi_mq_real[ir]) / ucell_omega; // Phase e^(i(q-k)r) - } - // to be changed into kernel function - + cal_density_recip(psi_nk_real, psi_mq_real, ucell->omega); + // bring the density to recip space - rhopw->real2recip(density_real, density_recip); + // rhopw->real2recip(density_real, density_recip); // multiply the density with the potential in recip space multiply_potential(density_recip, this->ik, iq); // bring the potential back to real space - rhopw->recip2real(density_recip, density_real); + // rhopw_dev->recip2real(density_recip, 
density_real); + rho_recip2real(density_recip, density_real); - // get the h|psi_ik>(r), save in density_real - #ifdef _OPENMP - #pragma omp parallel for schedule(static) - #endif - for (int ir = 0; ir < wfcpw->nrxx; ir++) + if (false) { - // assert(is_finite(psi_mq_real[ir])); - // assert(is_finite(density_real[ir])); - density_real[ir] *= psi_mq_real[ir]; + // do nothing + } + else + { + vec_mul_vec_complex_op()(density_real, psi_mq_real, density_real, wfcpw->nrxx); } T wk_iq = kv->wk[iq]; T wk_ik = kv->wk[this->ik]; - #ifdef _OPENMP - #pragma omp parallel for schedule(static) - #endif - for (int ir = 0; ir < wfcpw->nrxx; ir++) - { - h_psi_real[ir] += density_real[ir] * wg_mqb / wk_iq / nqs; - } + T tmp_scalar = wg_mqb / wk_iq / nqs; + axpy_complex_op()(wfcpw->nrxx, + &tmp_scalar, + density_real, + 1, + h_psi_real, + 1); } // end of m_iband - setmem_complex_op()(density_real, 0, rhopw->nrxx); - setmem_complex_op()(density_recip, 0, rhopw->npw); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); } // end of iq T* h_psi_nk = tmhpsi + n_iband * nbasis; Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha; wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha); - setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); - + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + } ModuleBase::timer::tick("OperatorEXXPW", "act_op"); - + } template @@ -307,7 +273,6 @@ void OperatorEXXPW::act_op_ace(const int nbands, const bool is_first_node) const { ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace"); - // std::cout << "act_op_ace" << std::endl; // hpsi += -Xi^\dagger * Xi * psi T* Xi_ace = Xi_ace_k[this->ik]; @@ -410,6 +375,8 @@ void OperatorEXXPW::construct_ace() const resmem_complex_op()(psi_h_psi_ace, nbands * nbands); } + if (first_iter) return; + for (int ik = 0; ik < nk; ik++) { int npwk = wfcpw->npwk[ik]; @@ -451,43 +418,28 
@@ void OperatorEXXPW::construct_ace() const // reduction of psi_h_psi_ace, due to distributed memory Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands); - // L_ace = cholesky(-psi_h_psi_ace) - #ifdef _OPENMP - #pragma omp parallel for schedule(static) - #endif - for (int i = 0; i < nbands; i++) - { - for (int j = 0; j < nbands; j++) - { - L_ace[i * nbands + j] = -psi_h_psi_ace[i * nbands + j]; - } - } + T intermediate_minus_one = -1.0; + axpy_complex_op()(nbands * nbands, + &intermediate_minus_one, + psi_h_psi_ace, + 1, + L_ace, + 1); + int info = 0; char up = 'U', lo = 'L'; - potrf_op()(&lo, &nbands, L_ace, &nbands, &info); + lapack_potrf()(lo, nbands, L_ace, nbands); // expand for-loop - #ifdef _OPENMP - #pragma omp parallel for schedule(static) collapse(2) - #endif - for (int i = 0; i < nbands; i++) - { - for (int j = 0; j < nbands; j++) - { - if (j < i) - { - // L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]); - L_ace[i * nbands + j] = 0.0; - } - } + for (int i = 0; i < nbands; ++i) { + setmem_complex_op()(L_ace + i * nbands, 0, i); } // L_ace inv in place - // T == std::complex or std::complex char non = 'N'; - trtri_op()(&lo, &non, &nbands, L_ace, &nbands, &info); + lapack_trtri()(lo, non, nbands, L_ace, nbands); // Xi_ace = L_ace^-1 * h_psi_ace^dagger gemm_complex_op()('N', @@ -566,20 +518,12 @@ template void OperatorEXXPW::multiply_potential(T *density_recip, int ik, int iq) const { ModuleBase::timer::tick("OperatorEXXPW", "multiply_potential"); - int npw = rhopw->npw; + int npw = rhopw_dev->npw; int nks = wfcpw->nks; int nk_fac = PARAM.inp.nspin == 2 ? 
2 : 1; int nk = nks / nk_fac; - #ifdef _OPENMP - #pragma omp parallel for schedule(static) - #endif - for (int ig = 0; ig < npw; ig++) - { - int ig_kq = ik * nks * npw + iq * npw + ig; - density_recip[ig] *= pot[ig_kq]; - - } + mul_potential_op()(pot, density_recip, npw, nks, ik, iq); ModuleBase::timer::tick("OperatorEXXPW", "multiply_potential"); } @@ -601,14 +545,15 @@ OperatorEXXPW::OperatorEXXPW(const OperatorEXXPW *op this->isk = op->isk; this->wfcpw = op->wfcpw; this->rhopw = op->rhopw; + this->rhopw_dev = op->rhopw_dev; /* NOTE(review): pointer is shared with op, yet ~OperatorEXXPW() does `delete rhopw_dev` -- destroying both objects would free it twice; confirm ownership */ this->psi = op->psi; this->ctx = op->ctx; this->cpu_ctx = op->cpu_ctx; resmem_complex_op()(this->ctx, psi_nk_real, wfcpw->nrxx); resmem_complex_op()(this->ctx, psi_mq_real, wfcpw->nrxx); - resmem_complex_op()(this->ctx, density_real, rhopw->nrxx); - resmem_complex_op()(this->ctx, h_psi_real, rhopw->nrxx); - resmem_complex_op()(this->ctx, density_recip, rhopw->npw); + resmem_complex_op()(this->ctx, density_real, rhopw_dev->nrxx); + resmem_complex_op()(this->ctx, h_psi_real, rhopw_dev->nrxx); + resmem_complex_op()(this->ctx, density_recip, rhopw_dev->npw); resmem_complex_op()(this->ctx, h_psi_recip, wfcpw->npwk_max); // this->pws.resize(wfcpw->nks); @@ -622,9 +567,13 @@ void OperatorEXXPW::get_potential() const Real nqs_half2 = 0.5 * kv->nmp[1]; Real nqs_half3 = 0.5 * kv->nmp[2]; - setmem_real_op()(pot, 0, rhopw->npw * wfcpw->nks * wfcpw->nks); - int nks = wfcpw->nks, npw = rhopw->npw; + Real* pot_cpu = nullptr; + int nks = wfcpw->nks, npw = rhopw_dev->npw; double tpiba2 = tpiba * tpiba; + pot_cpu = new Real[npw * wfcpw->nks * wfcpw->nks]; /* NOTE(review): element count is computed in int arithmetic -- confirm npw * nks * nks cannot exceed INT_MAX */ + // fill zero + setmem_real_cpu_op()(pot_cpu, 0, npw * nks * nks); + // calculate Fock pot auto param_fock = GlobalC::exx_info.info_global.coulomb_param[Conv_Coulomb_Pot_K::Coulomb_Type::Fock]; for (auto param : param_fock) @@ -643,9 +592,9 @@ void OperatorEXXPW::get_potential() const #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < rhopw->npw; ig++) + for (int ig = 0; 
ig < rhopw_dev->npw; ig++) { - const ModuleBase::Vector3 g_d = rhopw->gdirect[ig]; + const ModuleBase::Vector3 g_d = rhopw_dev->gdirect[ig]; const ModuleBase::Vector3 kqg_d = k_d - q_d + g_d; // For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114) // 7/8 of the points in the grid are "activated" and 1/8 are disabled. @@ -673,17 +622,17 @@ void OperatorEXXPW::get_potential() const const int nk = nks / nk_fac; const int ig_kq = ik * nks * npw + iq * npw + ig; - Real gg = (k_c - q_c + rhopw->gcar[ig]).norm2() * tpiba2; + Real gg = (k_c - q_c + rhopw_dev->gcar[ig]).norm2() * tpiba2; // if (kqgcar2 > 1e-12) // vasp uses 1/40 of the smallest (k spacing)**2 if (gg >= 1e-8) { Real fac = -ModuleBase::FOUR_PI * ModuleBase::e2 / gg; - pot[ig_kq] += fac * grid_factor * alpha; + pot_cpu[ig_kq] += fac * grid_factor * alpha; } // } else { - pot[ig_kq] += exx_div * alpha; + pot_cpu[ig_kq] += exx_div * alpha; } // assert(is_finite(density_recip[ig])); } @@ -711,9 +660,9 @@ void OperatorEXXPW::get_potential() const #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < rhopw->npw; ig++) + for (int ig = 0; ig < rhopw_dev->npw; ig++) { - const ModuleBase::Vector3 g_d = rhopw->gdirect[ig]; + const ModuleBase::Vector3 g_d = rhopw_dev->gdirect[ig]; const ModuleBase::Vector3 kqg_d = k_d - q_d + g_d; // For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114) // 7/8 of the points in the grid are "activated" and 1/8 are disabled. 
@@ -741,12 +690,12 @@ void OperatorEXXPW::get_potential() const
                     const int nk = nks / nk_fac;
                     const int ig_kq = ik * nks * npw + iq * npw + ig;
 
-                    Real gg = (k_c - q_c + rhopw->gcar[ig]).norm2() * tpiba2;
+                    Real gg = (k_c - q_c + rhopw_dev->gcar[ig]).norm2() * tpiba2;
                     // if (kqgcar2 > 1e-12) // vasp uses 1/40 of the smallest (k spacing)**2
                     if (gg >= 1e-8)
                     {
                         Real fac = -ModuleBase::FOUR_PI * ModuleBase::e2 / gg;
-                        pot[ig_kq] += fac * (1.0 - std::exp(-gg / 4.0 / erfc_omega2)) * grid_factor * alpha;
+                        pot_cpu[ig_kq] += fac * (1.0 - std::exp(-gg / 4.0 / erfc_omega2)) * grid_factor * alpha;
                     }
                     // }
                     else
@@ -754,11 +703,11 @@ void OperatorEXXPW::get_potential() const
                         // if (PARAM.inp.dft_functional == "hse")
                         if (!gamma_extrapolation)
                         {
-                            pot[ig_kq] += (exx_div - ModuleBase::PI * ModuleBase::e2 / erfc_omega2) * alpha;
+                            pot_cpu[ig_kq] += (exx_div - ModuleBase::PI * ModuleBase::e2 / erfc_omega2) * alpha;
                         }
                         else
                         {
-                            pot[ig_kq] += exx_div * alpha;
+                            pot_cpu[ig_kq] += exx_div * alpha;
                         }
                     }
                     // assert(is_finite(density_recip[ig]));
@@ -766,6 +715,11 @@ void OperatorEXXPW::get_potential() const
             }
         }
     }
+
+    // copy the potential from the host staging buffer (pot_cpu) to the device memory (pot)
+    syncmem_real_c2d_op()(pot, pot_cpu, rhopw_dev->npw * wfcpw->nks * wfcpw->nks);
+
+    delete[] pot_cpu; // pot_cpu was allocated with new Real[...], so it must be released with delete[]
 }
 
 template
@@ -793,10 +747,10 @@ double OperatorEXXPW::exx_divergence(Conv_Coulomb_Pot_K::Coulomb_Type
         #ifdef _OPENMP
         #pragma omp parallel for reduction(+:div)
         #endif
-        for (int ig = 0; ig < rhopw->npw; ig++)
+        for (int ig = 0; ig < rhopw_dev->npw; ig++)
         {
-            const ModuleBase::Vector3 q_c = k_c + rhopw->gcar[ig];
-            const ModuleBase::Vector3 q_d = k_d + rhopw->gdirect[ig];
+            const ModuleBase::Vector3 q_c = k_c + rhopw_dev->gcar[ig];
+            const ModuleBase::Vector3 q_d = k_d + rhopw_dev->gdirect[ig];
             double qq = q_c.norm2();
             // For gamma_extrapolation (https://doi.org/10.1103/PhysRevB.79.205114)
             // 7/8 of the points in the grid are "activated" and 1/8 are disabled.
@@ -946,12 +900,12 @@ double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) c using setmem_complex_op = base_device::memory::set_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; - T* psi_nk_real = new T[wfcpw->nrxx]; - T* psi_mq_real = new T[wfcpw->nrxx]; - T* h_psi_recip = new T[wfcpw->npwk_max]; - T* h_psi_real = new T[wfcpw->nrxx]; - T* density_real = new T[wfcpw->nrxx]; - T* density_recip = new T[rhopw->npw]; + setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); if (wg == nullptr) return 0.0; const int nk_fac = PARAM.inp.nspin == 2 ? 2 : 1; @@ -963,9 +917,9 @@ double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) c for (int n_iband = 0; n_iband < psi.get_nbands(); n_iband++) { setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); - setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); - setmem_complex_op()(density_real, 0, rhopw->nrxx); - setmem_complex_op()(density_recip, 0, rhopw->npw); + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); // double wg_ikb_real = GlobalC::exx_helper.wg(this->ik, n_iband); double wg_ikb_real = (*wg)(ik, n_iband); @@ -1007,35 +961,12 @@ double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) c // const T* psi_mq = get_pw(m_iband, iq); wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); - T omega_inv = 1.0 / ucell->omega; + cal_density_recip(psi_nk_real, psi_mq_real, ucell->omega); - // direct multiplication in real space, \psi_nk(r) * \psi_mq(r) - #ifdef _OPENMP - #pragma omp parallel for - #endif - for (int ir = 0; ir < wfcpw->nrxx; ir++) - { - // assert(is_finite(psi_nk_real[ir])); - // 
assert(is_finite(psi_mq_real[ir])); - density_real[ir] = psi_nk_real[ir] * std::conj(psi_mq_real[ir]) * omega_inv; - } - // to be changed into kernel function - - // bring the density to recip space - rhopw->real2recip(density_real, density_recip); - - #ifdef _OPENMP - #pragma omp parallel for reduction(+:Eexx_ik_real) - #endif - for (int ig = 0; ig < rhopw->npw; ig++) - { - int nks = wfcpw->nks; - int npw = rhopw->npw; - int nk = nks / nk_fac; - Real Fac = pot[ik * nks * npw + iq * npw + ig]; - Eexx_ik_real += Fac * (density_recip[ig] * std::conj(density_recip[ig])).real() - * wg_iqb_real / nqs * wg_ikb_real / kv->wk[ik]; - } + int nks = wfcpw->nks; + int npw = rhopw_dev->npw; + int nk = nks / nk_fac; + Eexx_ik_real += exx_cal_energy_op()(density_recip, pot + ik * nks * npw + iq * npw, wg_iqb_real / nqs * wg_ikb_real / kv->wk[ik], npw); } // m_iband @@ -1048,39 +979,50 @@ double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) c Parallel_Reduce::reduce_pool(Eexx_ik_real); // std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl; - delete[] psi_nk_real; - delete[] psi_mq_real; - delete[] h_psi_recip; - delete[] h_psi_real; - delete[] density_real; - delete[] density_recip; + setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); + setmem_complex_op()(h_psi_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_real, 0, rhopw_dev->nrxx); + setmem_complex_op()(density_recip, 0, rhopw_dev->npw); - double Eexx = Eexx_ik_real; - return Eexx; + return Eexx_ik_real; } template <> -void trtri_op, base_device::DEVICE_CPU>::operator()(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info) +void OperatorEXXPW, base_device::DEVICE_CPU>::cal_density_recip(const std::complex* psi_nk_real, + const std::complex* psi_mq_real, + double omega) const { - ctrtri_(uplo, diag, n, a, lda, info); + 
cal_density_real_op, base_device::DEVICE_CPU>()(psi_nk_real, psi_mq_real, density_real, omega, wfcpw->nrxx); + rhopw_dev->real2recip(density_real, density_recip); } template <> -void trtri_op, base_device::DEVICE_CPU>::operator()(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info) +void OperatorEXXPW, base_device::DEVICE_CPU>::cal_density_recip(const std::complex* psi_nk_real, + const std::complex* psi_mq_real, + double omega) const { - ztrtri_(uplo, diag, n, a, lda, info); + cal_density_real_op, base_device::DEVICE_CPU>()(psi_nk_real, psi_mq_real, density_real, omega, wfcpw->nrxx); + rhopw_dev->real2recip(density_real, density_recip); } template <> -void potrf_op, base_device::DEVICE_CPU>::operator()(char *uplo, int *n, std::complex *a, int *lda, int *info) +void OperatorEXXPW, base_device::DEVICE_CPU>::rho_recip2real(const std::complex* rho_recip, + std::complex* rho_real, + bool add, + double factor) const { - cpotrf_(uplo, n, a, lda, info); + rhopw_dev->recip2real(rho_recip, rho_real, add, factor); } template <> -void potrf_op, base_device::DEVICE_CPU>::operator()(char *uplo, int *n, std::complex *a, int *lda, int *info) +void OperatorEXXPW, base_device::DEVICE_CPU>::rho_recip2real(const std::complex* rho_recip, + std::complex* rho_real, + bool add, + float factor) const { - zpotrf_(uplo, n, a, lda, info); + rhopw_dev->recip2real(rho_recip, rho_real, add, factor); } template class OperatorEXXPW, base_device::DEVICE_CPU>; @@ -1088,6 +1030,43 @@ template class OperatorEXXPW, base_device::DEVICE_CPU>; #if ((defined __CUDA) || (defined __ROCM)) template class OperatorEXXPW, base_device::DEVICE_GPU>; template class OperatorEXXPW, base_device::DEVICE_GPU>; + +template <> +void OperatorEXXPW, base_device::DEVICE_GPU>::cal_density_recip(const std::complex* psi_nk_real, + const std::complex* psi_mq_real, + double omega) const +{ + cal_density_real_op, base_device::DEVICE_GPU>()(psi_nk_real, psi_mq_real, density_real, omega, wfcpw->nrxx); + 
rhopw_dev->real2recip_gpu(density_real, density_recip); +} + +template <> +void OperatorEXXPW, base_device::DEVICE_GPU>::cal_density_recip(const std::complex* psi_nk_real, + const std::complex* psi_mq_real, + double omega) const +{ + cal_density_real_op, base_device::DEVICE_GPU>()(psi_nk_real, psi_mq_real, density_real, omega, wfcpw->nrxx); + rhopw_dev->real2recip_gpu(density_real, density_recip); +} + +template <> +void OperatorEXXPW, base_device::DEVICE_GPU>::rho_recip2real(const std::complex* rho_recip, + std::complex* rho_real, + bool add, + double factor) const +{ + rhopw_dev->recip2real_gpu(rho_recip, rho_real, add, factor); +} + +template <> +void OperatorEXXPW, base_device::DEVICE_GPU>::rho_recip2real(const std::complex* rho_recip, + std::complex* rho_real, + bool add, + float factor) const +{ + rhopw_dev->recip2real_gpu(rho_recip, rho_real, add, factor); +} + #endif } // namespace hamilt diff --git a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h index ee9501df66..51b69b13a9 100644 --- a/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h +++ b/source/source_pw/module_pwdft/operator_pw/op_exx_pw.h @@ -10,6 +10,7 @@ #include "source_cell/klist.h" #include "source_lcao/module_ri/conv_coulomb_pot_k.h" #include "source_psi/psi.h" +#include "source_base/module_container/ATen/kernels/lapack.h" #include #include @@ -55,9 +56,10 @@ class OperatorEXXPW : public OperatorPW bool first_iter = false; private: - const int *isk = nullptr; - const ModulePW::PW_Basis_K *wfcpw = nullptr; - const ModulePW::PW_Basis *rhopw = nullptr; + const int* isk = nullptr; + const ModulePW::PW_Basis_K* wfcpw = nullptr; + const ModulePW::PW_Basis* rhopw = nullptr; + ModulePW::PW_Basis* rhopw_dev = nullptr; // for device const UnitCell *ucell = nullptr; // Real exx_div = 0; Real tpiba = 0; @@ -91,6 +93,10 @@ class OperatorEXXPW : public OperatorPW double cal_exx_energy_ace(psi::Psi *psi_) const; + void cal_density_recip(const 
T* psi_nk_real, const T* psi_mq_real, double omega) const; + + void rho_recip2real(const T* rho_recip, T* rho_real, bool add = false, Real factor = 1.0) const; + mutable int cnt = 0; mutable bool potential_got = false; @@ -134,16 +140,25 @@ class OperatorEXXPW : public OperatorPW base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; + using ct_Device = typename ct::PsiToContainer::type; using setmem_complex_op = base_device::memory::set_memory_op; using setmem_real_op = base_device::memory::set_memory_op; + using setmem_real_cpu_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; using syncmem_complex_op = base_device::memory::synchronize_memory_op; using resmem_real_op = base_device::memory::resize_memory_op; using delmem_real_op = base_device::memory::delete_memory_op; using gemm_complex_op = ModuleBase::gemm_op; + using axpy_complex_op = ModuleBase::axpy_op; using vec_add_vec_complex_op = ModuleBase::vector_add_vector_op; using dot_op = ModuleBase::dot_real_op; + using syncmem_complex_c2d_op = base_device::memory::synchronize_memory_op; + using syncmem_complex_d2c_op = base_device::memory::synchronize_memory_op; + using syncmem_real_c2d_op = base_device::memory::synchronize_memory_op; + using syncmem_real_d2c_op = base_device::memory::synchronize_memory_op; + using lapack_potrf = container::kernels::lapack_potrf; + using lapack_trtri = container::kernels::lapack_trtri; bool gamma_extrapolation = true;