Skip to content

Commit c53f445

Browse files
Refactor:Remove update_tau_pos in ucell (#5783)
* modify periodic_boundary_adjustment * modify update_pos_tau * update compile * delete ucell referenc in update_pos_tau * add unittest for update_pos_tau * move back test file * use EXPECT_THAT instead of EXPECT_EQ in relax_old and use regex to remove the title * remove the bug in the relax_old for it didn't run update_pos * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 7ae18a5 commit c53f445

24 files changed

+301
-137
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ OBJS_CELL=atom_pseudo.o\
188188
cell_index.o\
189189
check_atomic_stru.o\
190190
update_cell.o\
191+
bcast_cell.o\
191192

192193
OBJS_DEEPKS=LCAO_deepks.o\
193194
deepks_force.o\

source/module_cell/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_library(
2424
cell_index.cpp
2525
check_atomic_stru.cpp
2626
update_cell.cpp
27+
bcast_cell.cpp
2728
)
2829

2930
if(ENABLE_COVERAGE)

source/module_cell/bcast_cell.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include "unitcell.h"
2+
3+
namespace unitcell
4+
{
5+
void bcast_atoms_tau(Atom* atoms,
6+
const int ntype)
7+
{
8+
#ifdef __MPI
9+
MPI_Barrier(MPI_COMM_WORLD);
10+
for (int i = 0; i < ntype; i++) {
11+
atoms[i].bcast_atom(); // bcast tau array
12+
}
13+
#endif
14+
}
15+
}

source/module_cell/bcast_cell.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef BCAST_CELL_H
2+
#define BCAST_CELL_H
3+
4+
namespace unitcell
5+
{
6+
void bcast_atoms_tau(Atom* atoms,
7+
const int ntype);
8+
}
9+
10+
#endif // BCAST_CELL_H

source/module_cell/test/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ install(FILES unitcell_test_parallel.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
1414

1515
list(APPEND cell_simple_srcs
1616
../unitcell.cpp
17+
../update_cell.cpp
18+
../bcast_cell.cpp
1719
../read_atoms.cpp
1820
../atom_spec.cpp
1921
../atom_pseudo.cpp
@@ -103,14 +105,14 @@ add_test(NAME cell_parallel_kpoints_test
103105
AddTest(
104106
TARGET cell_unitcell_test
105107
LIBS parameter ${math_libs} base device cell_info symmetry
106-
SOURCES unitcell_test.cpp ../../module_io/output.cpp ../../module_elecstate/cal_ux.cpp ../update_cell.cpp
108+
SOURCES unitcell_test.cpp ../../module_io/output.cpp ../../module_elecstate/cal_ux.cpp
107109

108110
)
109111

110112
AddTest(
111113
TARGET cell_unitcell_test_readpp
112114
LIBS parameter ${math_libs} base device cell_info
113-
SOURCES unitcell_test_readpp.cpp ../../module_io/output.cpp
115+
SOURCES unitcell_test_readpp.cpp ../../module_io/output.cpp
114116
)
115117

116118
AddTest(
@@ -123,7 +125,6 @@ AddTest(
123125
TARGET cell_unitcell_test_setupcell
124126
LIBS parameter ${math_libs} base device cell_info
125127
SOURCES unitcell_test_setupcell.cpp ../../module_io/output.cpp
126-
../../module_cell/update_cell.cpp
127128
)
128129

129130
add_test(NAME cell_unitcell_test_parallel

source/module_cell/test/support/mock_unitcell.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
3333
std::ofstream& ofs_warning) {
3434
return true;
3535
}
36-
void UnitCell::update_pos_tau(const double* pos) {}
3736
void UnitCell::update_pos_taud(double* posd_in) {}
3837
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
3938
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
40-
void UnitCell::periodic_boundary_adjustment() {}
4139
void UnitCell::bcast_atoms_tau() {}
4240
bool UnitCell::judge_big_cell() const { return true; }
4341
void UnitCell::update_stress(ModuleBase::matrix& scs) {}

source/module_cell/test/unitcell_test.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,9 @@ TEST_F(UcellDeathTest, PeriodicBoundaryAdjustment1)
783783
PARAM.input.relax_new = utp.relax_new;
784784
ucell = utp.SetUcellInfo();
785785
testing::internal::CaptureStdout();
786-
EXPECT_EXIT(ucell->periodic_boundary_adjustment(), ::testing::ExitedWithCode(1), "");
786+
EXPECT_EXIT(unitcell::periodic_boundary_adjustment(
787+
ucell->atoms,ucell->latvec,ucell->ntype),
788+
::testing::ExitedWithCode(1), "");
787789
std::string output = testing::internal::GetCapturedStdout();
788790
EXPECT_THAT(output, testing::HasSubstr("the movement of atom is larger than the length of cell"));
789791
}
@@ -793,7 +795,8 @@ TEST_F(UcellTest, PeriodicBoundaryAdjustment2)
793795
UcellTestPrepare utp = UcellTestLib["C1H2-Index"];
794796
PARAM.input.relax_new = utp.relax_new;
795797
ucell = utp.SetUcellInfo();
796-
EXPECT_NO_THROW(ucell->periodic_boundary_adjustment());
798+
EXPECT_NO_THROW(unitcell::periodic_boundary_adjustment(
799+
ucell->atoms,ucell->latvec,ucell->ntype));
797800
}
798801

799802
TEST_F(UcellTest, PrintCell)

source/module_cell/test/unitcell_test_para.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
#include "mpi.h"
1515
#endif
1616
#include "prepare_unitcell.h"
17-
17+
#include "../update_cell.h"
18+
#include "../bcast_cell.h"
1819
#ifdef __LCAO
1920
InfoNonlocal::InfoNonlocal()
2021
{
@@ -44,6 +45,7 @@ Magnetism::~Magnetism()
4445
/**
4546
* - Tested Functions:
4647
* - UpdatePosTaud
48+
* - update_pos_tau(double* pos)
4749
* - update_pos_taud(const double* pos)
4850
* - bcast_atoms_tau() is also called in the above function, which calls Atom::bcast_atom with many
4951
* atomic info in addition to tau
@@ -123,7 +125,34 @@ TEST_F(UcellTest, BcastUnitcell)
123125
EXPECT_EQ(atom_labels[1], atom_type2_expected);
124126
}
125127
}
126-
128+
TEST_F(UcellTest, UpdatePosTau)
129+
{
130+
double* pos_in = new double[ucell->nat * 3];
131+
ucell->set_iat2itia();
132+
std::fill(pos_in, pos_in + ucell->nat * 3, 0);
133+
for (int iat = 0; iat < ucell->nat; ++iat)
134+
{
135+
int it, ia;
136+
ucell->iat2iait(iat, &ia, &it);
137+
for (int ik = 0; ik < 3; ++ik)
138+
{
139+
ucell->atoms[it].mbl[ia][ik] = true;
140+
pos_in[iat * 3 + ik] = (iat * 3 + ik) / (ucell->nat * 3.0) * (ucell->lat.lat0);
141+
}
142+
}
143+
unitcell::update_pos_tau(ucell->lat,pos_in,ucell->ntype,ucell->nat,ucell->atoms);
144+
for (int iat = 0; iat < ucell->nat; ++iat)
145+
{
146+
int it, ia;
147+
ucell->iat2iait(iat, &ia, &it);
148+
for (int ik = 0; ik < 3; ++ik)
149+
{
150+
EXPECT_DOUBLE_EQ(ucell->atoms[it].tau[ia][ik],
151+
(iat*3+ik)/(ucell->nat*3.0));
152+
}
153+
}
154+
delete[] pos_in;
155+
}
127156
TEST_F(UcellTest, UpdatePosTaud)
128157
{
129158
double* pos_in = new double[ucell->nat * 3];
@@ -147,6 +176,7 @@ TEST_F(UcellTest, UpdatePosTaud)
147176
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia].y, tmp[iat].y + 0.01);
148177
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia].z, tmp[iat].z + 0.01);
149178
}
179+
delete[] tmp;
150180
delete[] pos_in;
151181
}
152182

source/module_cell/test_pw/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ install(FILES unitcell_test_pw_para.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
1111
AddTest(
1212
TARGET cell_unitcell_test_pw
1313
LIBS parameter ${math_libs} base device
14-
SOURCES unitcell_test_pw.cpp ../unitcell.cpp ../read_atoms.cpp ../atom_spec.cpp
14+
SOURCES unitcell_test_pw.cpp ../unitcell.cpp ../read_atoms.cpp ../atom_spec.cpp ../update_cell.cpp ../bcast_cell.cpp
1515
../atom_pseudo.cpp ../pseudo.cpp ../read_pp.cpp ../read_pp_complete.cpp ../read_pp_upf201.cpp ../read_pp_upf100.cpp
1616
../read_pp_vwr.cpp ../read_pp_blps.cpp ../../module_io/output.cpp ../../module_elecstate/read_pseudo.cpp ../../module_elecstate/cal_nelec_nband.cpp
1717
)

source/module_cell/unitcell.cpp

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "module_ri/serialization_cereal.h"
2828
#endif
2929

30+
31+
#include "update_cell.h"
3032
UnitCell::UnitCell() {
3133
if (test_unitcell) {
3234
ModuleBase::TITLE("unitcell", "Constructor");
@@ -312,29 +314,7 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
312314
return constrain;
313315
}
314316

315-
void UnitCell::update_pos_tau(const double* pos) {
316-
int iat = 0;
317-
for (int it = 0; it < this->ntype; it++) {
318-
Atom* atom = &this->atoms[it];
319-
for (int ia = 0; ia < atom->na; ia++) {
320-
for (int ik = 0; ik < 3; ++ik) {
321-
if (atom->mbl[ia][ik]) {
322-
atom->dis[ia][ik]
323-
= pos[3 * iat + ik] / this->lat0 - atom->tau[ia][ik];
324-
atom->tau[ia][ik] = pos[3 * iat + ik] / this->lat0;
325-
}
326-
}
327317

328-
// the direct coordinates also need to be updated.
329-
atom->dis[ia] = atom->dis[ia] * this->GT;
330-
atom->taud[ia] = atom->tau[ia] * this->GT;
331-
iat++;
332-
}
333-
}
334-
assert(iat == this->nat);
335-
this->periodic_boundary_adjustment();
336-
this->bcast_atoms_tau();
337-
}
338318

339319
void UnitCell::update_pos_taud(double* posd_in) {
340320
int iat = 0;
@@ -349,7 +329,7 @@ void UnitCell::update_pos_taud(double* posd_in) {
349329
}
350330
}
351331
assert(iat == this->nat);
352-
this->periodic_boundary_adjustment();
332+
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
353333
this->bcast_atoms_tau();
354334
}
355335

@@ -367,7 +347,7 @@ void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
367347
}
368348
}
369349
assert(iat == this->nat);
370-
this->periodic_boundary_adjustment();
350+
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
371351
this->bcast_atoms_tau();
372352
}
373353

@@ -383,54 +363,6 @@ void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
383363
assert(iat == this->nat);
384364
}
385365

386-
void UnitCell::periodic_boundary_adjustment() {
387-
//----------------------------------------------
388-
// because of the periodic boundary condition
389-
// we need to adjust the atom positions,
390-
// first adjust direct coordinates,
391-
// then update them into cartesian coordinates,
392-
//----------------------------------------------
393-
for (int it = 0; it < this->ntype; it++) {
394-
Atom* atom = &this->atoms[it];
395-
for (int ia = 0; ia < atom->na; ia++) {
396-
// mohan update 2011-03-21
397-
if (atom->taud[ia].x < 0) {
398-
atom->taud[ia].x += 1.0;
399-
}
400-
if (atom->taud[ia].y < 0) {
401-
atom->taud[ia].y += 1.0;
402-
}
403-
if (atom->taud[ia].z < 0) {
404-
atom->taud[ia].z += 1.0;
405-
}
406-
if (atom->taud[ia].x >= 1.0) {
407-
atom->taud[ia].x -= 1.0;
408-
}
409-
if (atom->taud[ia].y >= 1.0) {
410-
atom->taud[ia].y -= 1.0;
411-
}
412-
if (atom->taud[ia].z >= 1.0) {
413-
atom->taud[ia].z -= 1.0;
414-
}
415-
416-
if (atom->taud[ia].x < 0 || atom->taud[ia].y < 0
417-
|| atom->taud[ia].z < 0 || atom->taud[ia].x >= 1.0
418-
|| atom->taud[ia].y >= 1.0 || atom->taud[ia].z >= 1.0) {
419-
GlobalV::ofs_warning << " it=" << it + 1 << " ia=" << ia + 1
420-
<< std::endl;
421-
GlobalV::ofs_warning << "d=" << atom->taud[ia].x << " "
422-
<< atom->taud[ia].y << " "
423-
<< atom->taud[ia].z << std::endl;
424-
ModuleBase::WARNING_QUIT(
425-
"Ions_Move_Basic::move_ions",
426-
"the movement of atom is larger than the length of cell.");
427-
}
428-
429-
atom->tau[ia] = atom->taud[ia] * this->latvec;
430-
}
431-
}
432-
return;
433-
}
434366

435367
void UnitCell::bcast_atoms_tau() {
436368
#ifdef __MPI

0 commit comments

Comments
 (0)