Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
this->ld->lmaxd,
this->ld->inl2l,
this->ld->inl_index,
this->kvec_d,
this->DM,
this->ld->phialpha,
*this->ucell,
Expand Down
34 changes: 32 additions & 2 deletions source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
using TH = std::conditional_t<std::is_same<TK, double>::value, ModuleBase::matrix, ModuleBase::ComplexMatrix>;

// These variables are frequently used in the following code
const int inlmax = orb.Alpha[0].getTotal_nchi() * nat;
const int nlmax = orb.Alpha[0].getTotal_nchi();
const int inlmax = nlmax * nat;
const int lmaxd = orb.get_lmax_d();
const int nmaxd = ld->nmaxd;

Expand All @@ -62,7 +63,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
// this part is for integrated test of deepks
// so it is printed no matter even if deepks_out_labels is not used
DeePKS_domain::cal_pdm<
TK>(init_pdm, inlmax, lmaxd, inl2l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
TK>(init_pdm, inlmax, lmaxd, inl2l, inl_index, kvec_d, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);

DeePKS_domain::check_pdm(inlmax, inl2l, pdm); // print out the projected dm for NSCF calculaiton

Expand Down Expand Up @@ -312,6 +313,35 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
out_hr.write();
ofs_hr.close();
}

const std::string file_vdrpre = PARAM.globalv.global_out_dir + "deepks_vdrpre.csr";
std::vector<hamilt::HContainer<TR>*> h_deltaR_pre(inlmax);
for (int i = 0; i < inlmax; i++)
{
h_deltaR_pre[i] = new hamilt::HContainer<TR>(*hR_tot);
h_deltaR_pre[i]->set_zero();
}
// DeePKS_domain::cal_vdr_precalc<TR>();
if (rank == 0)
{
std::ofstream ofs_hrp(file_vdrpre, std::ios::out);
for (int iat = 0; iat < nat; iat++)
{
ofs_hrp << "- Index of atom: " << iat << std::endl;
for (int nl = 0; nl < nlmax; nl++)
{
int inl = iat * nlmax + nl;
ofs_hrp << "-- Index of nl: " << nl << std::endl;
ofs_hrp << "Matrix Dimension of H_delta(R): " << h_deltaR_pre[inl]->get_nbasis() << std::endl;
ofs_hrp << "Matrix number of H_delta(R): " << h_deltaR_pre[inl]->size_R_loop() << std::endl;
hamilt::Output_HContainer<TR> out_hrp(h_deltaR_pre[inl], ofs_hrp, sparse_threshold, precision);
out_hrp.write();
ofs_hrp << std::endl;
}
ofs_hrp << std::endl;
}
ofs_hrp.close();
}
}
}

Expand Down
91 changes: 46 additions & 45 deletions source/module_hamilt_lcao/module_deepks/deepks_pdm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
const int lmaxd,
const std::vector<int>& inl2l,
const ModuleBase::IntArray* inl_index,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const elecstate::DensityMatrix<TK, double>* dm,
const std::vector<hamilt::HContainer<double>*> phialpha,
const UnitCell& ucell,
Expand Down Expand Up @@ -231,7 +232,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
}
}

for (int ad2 = 0; ad2 < adjs.adj_num + 1; ad2++)
for (int ad2 = 0; ad2 < adjs.adj_num + 1; ad2++)
{
const int T2 = adjs.ntype[ad2];
const int I2 = adjs.natom[ad2];
Expand Down Expand Up @@ -274,33 +275,31 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
// prepare DM from DMR
std::vector<double> dm_array(row_size * col_size, 0.0);
const double* dm_current = nullptr;
for (int is = 0; is < dm->get_DMR_vector().size(); is++)
int dRx = 0, dRy = 0, dRz = 0;
if constexpr (std::is_same<TK, std::complex<double>>::value)
{
int dRx = 0, dRy = 0, dRz = 0;
if constexpr (std::is_same<TK, std::complex<double>>::value)
{
dRx = dR2.x - dR1.x;
dRy = dR2.y - dR1.y;
dRz = dR2.z - dR1.z;
}
// dm_R
auto* tmp = dm->get_DMR_vector()[is]->find_matrix(ibt1, ibt2, dRx, dRy, dRz);
if (tmp == nullptr)
{
// in case of no deepks_scf but out_deepks_label, size of DMR would mismatch with
// deepks-orbitals
dm_current = nullptr;
break;
}
dm_current = tmp->get_pointer();
for (int idm = 0; idm < row_size * col_size; idm++)
{
dm_array[idm] += dm_current[idm];
}
dRx = dR2.x - dR1.x;
dRy = dR2.y - dR1.y;
dRz = dR2.z - dR1.z;
}
if (dm_current == nullptr)
// dm_k
auto dm_k = dm->get_DMK_vector();
const int nrow = pv.nrow;
for (int ir = 0; ir < row_size; ir++)
{
continue; // skip the long range DM pair more than nonlocal term
for (int ic = 0; ic < col_size; ic++)
{
int iglob = (pv.atom_begin_row[ibt1] + ir) + nrow * (pv.atom_begin_col[ibt2] + ic);
int iloc = ir * col_size + ic;
std::complex<double> tmp = 0.0;
for(int ik = 0; ik < dm_k.size(); ik++) // dm_k.size() == _nk * _nspin
{
const double arg = (kvec_d[ik] * ModuleBase::Vector3<double>(dR1 - dR2)) * ModuleBase::TWO_PI;
const std::complex<double> kphase = std::complex<double>(cos(arg), sin(arg));
tmp += dm_k[ik][iglob] * kphase;
}
dm_array[iloc] += tmp.real();
}
}

dm_current = dm_array.data();
Expand All @@ -311,18 +310,18 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
constexpr char transa = 'T', transb = 'N';
const double gemm_alpha = 1.0, gemm_beta = 1.0;
dgemm_(&transa,
&transb,
&row_size,
&trace_alpha_size,
&col_size,
&gemm_alpha,
dm_current,
&col_size,
s_2t.data(),
&col_size,
&gemm_beta,
g_1dmt.data(),
&row_size);
&transb,
&row_size,
&trace_alpha_size,
&col_size,
&gemm_alpha,
dm_current,
&col_size,
s_2t.data(),
&col_size,
&gemm_beta,
g_1dmt.data(),
&row_size);
} // ad2
if (!PARAM.inp.deepks_equiv)
{
Expand All @@ -340,10 +339,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
for (int m2 = 0; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
{
accessor[m1][m2] += ddot_(&row_size,
g_1dmt.data() + index * row_size,
&inc,
s_1t.data() + index * row_size,
&inc);
g_1dmt.data() + index * row_size,
&inc,
s_1t.data() + index * row_size,
&inc);
index++;
}
}
Expand All @@ -366,10 +365,10 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
// ddot_: dot product of two vectors
// inc means the increment of the index
accessor[iproj * nproj + jproj] += ddot_(&row_size,
g_1dmt.data() + index * row_size,
&inc,
s_1t.data() + index * row_size,
&inc);
g_1dmt.data() + index * row_size,
&inc,
s_1t.data() + index * row_size,
&inc);
index++;
}
}
Expand Down Expand Up @@ -414,6 +413,7 @@ template void DeePKS_domain::cal_pdm<double>(bool& init_pdm,
const int lmaxd,
const std::vector<int>& inl2l,
const ModuleBase::IntArray* inl_index,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const elecstate::DensityMatrix<double, double>* dm,
const std::vector<hamilt::HContainer<double>*> phialpha,
const UnitCell& ucell,
Expand All @@ -428,6 +428,7 @@ template void DeePKS_domain::cal_pdm<std::complex<double>>(
const int lmaxd,
const std::vector<int>& inl2l,
const ModuleBase::IntArray* inl_index,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const elecstate::DensityMatrix<std::complex<double>, double>* dm,
const std::vector<hamilt::HContainer<double>*> phialpha,
const UnitCell& ucell,
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_lcao/module_deepks/deepks_pdm.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void cal_pdm(bool& init_pdm,
const int lmaxd,
const std::vector<int>& inl2l,
const ModuleBase::IntArray* inl_index,
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
const elecstate::DensityMatrix<TK, double>* dm,
const std::vector<hamilt::HContainer<double>*> phialpha,
const UnitCell& ucell,
Expand Down
6 changes: 3 additions & 3 deletions source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
const Grid_Driver& GridD,
torch::Tensor& v_delta_precalc)
{
ModuleBase::TITLE("DeePKS_domain", "calc_v_delta_precalc");
ModuleBase::timer::tick("DeePKS_domain", "calc_v_delta_precalc");
ModuleBase::TITLE("DeePKS_domain", "cal_v_delta_precalc");
ModuleBase::timer::tick("DeePKS_domain", "cal_v_delta_precalc");
// timeval t_start;
// gettimeofday(&t_start,NULL);

Expand Down Expand Up @@ -230,7 +230,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
// std::cout<<"calculate v_delta_precalc time:\t"<<(double)(t_end.tv_sec-t_start.tv_sec) +
// (double)(t_end.tv_usec-t_start.tv_usec)/1000000.0<<std::endl;

ModuleBase::timer::tick("DeePKS_domain", "calc_v_delta_precalc");
ModuleBase::timer::tick("DeePKS_domain", "cal_v_delta_precalc");
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ void test_deepks<T>::check_pdm()
this->ld.lmaxd,
this->ld.inl2l,
this->ld.inl_index,
kv.kvec_d,
p_elec_DM,
this->ld.phialpha,
ucell,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ Output_HContainer<T>::Output_HContainer(hamilt::HContainer<T>* hcontainer,
int precision)
: _hcontainer(hcontainer), _ofs(ofs), _sparse_threshold(sparse_threshold), _precision(precision)
{
if (this->_sparse_threshold == -1)
{
this->_sparse_threshold = 1e-10;
}
if (this->_precision == -1)
{
this->_precision = 8;
}
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ template <typename T>
class Output_HContainer
{
public:
Output_HContainer(hamilt::HContainer<T>* hcontainer, std::ostream& ofs, double sparse_threshold, int precision);
Output_HContainer(hamilt::HContainer<T>* hcontainer, std::ostream& ofs, double sparse_threshold = -1, int precision = -1);
// write the matrices of all R vectors to the output stream
void write();

Expand Down
18 changes: 9 additions & 9 deletions tests/deepks/602_NO_deepks_d_H2O_md_lda2pbe/result.ref
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
etotref -465.9986234579913
etotperatomref -155.3328744860
totalforceref 5.535112
totalstressref 1.522354
totaldes 2.163682
deepks_e_dm -57.88576364957592
deepks_f_label 19.095631983991726
deepks_s_label 19.250613228828858
totaltimeref 22.06
etotref -465.9986233931722
etotperatomref -155.3328744644
totalforceref 5.535484
totalstressref 1.522431
totaldes 2.163702
deepks_e_dm -57.88572052925276
deepks_f_label 19.09562238352583
deepks_s_label 19.250613977989474
totaltimeref 14.90
12 changes: 6 additions & 6 deletions tests/deepks/603_NO_deepks_H2O_bandgap/result.ref
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
etotref -466.8189388994859
etotperatomref -155.6063129665
etotref -466.8189389058687
etotperatomref -155.6063129686
totaldes 4.392987
deepks_e_dm -49.145154045309184
odelta 0.05196672366779986
oprec 0.3729036159496012
totaltimeref 11.14
deepks_e_dm -49.145154051850646
odelta 0.05196672144400882
oprec 0.3729036136469043
totaltimeref 11.03
8 changes: 4 additions & 4 deletions tests/deepks/603_NO_deepks_H2O_multik/result.ref
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
etotref -466.8999964506085
etotperatomref -155.6333321502
totalforceref 10.047085
totaltimeref 21.21
etotref -466.8143178204656
etotperatomref -155.6047726068
totalforceref 10.244182
totaltimeref 7.52
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.00123076105
0.009421273749
2 changes: 1 addition & 1 deletion tests/deepks/604_NO_deepks_ut_H2O_multik/E_delta_ref.dat
Original file line number Diff line number Diff line change
@@ -1 +1 @@
-0.1964491488
-0.1883243273
6 changes: 3 additions & 3 deletions tests/deepks/604_NO_deepks_ut_H2O_multik/F_delta_ref.dat
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
0.01002305073 0.001383599869 -0.008194913779
-0.006744232782 -0.004657772796 0.0005568916214
-0.00327881795 0.003274172927 0.007638022158
0.002181260802 2.610262729e-05 -0.002150222026
-0.001268781011 -0.0007629016191 0.0002663644433
-0.000912479791 0.0007367989918 0.001883857582
6 changes: 3 additions & 3 deletions tests/deepks/604_NO_deepks_ut_H2O_multik/STRU
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ LATTICE_CONSTANT
1

LATTICE_VECTORS
10 0 0
0 10 0
0 0 10
15 0 0
0 15 0
0 0 15

ATOMIC_POSITIONS
Direct # Cartesian(Unit is LATTICE_CONSTANT)
Expand Down
Loading
Loading