Skip to content

Commit

Permalink
Fix rsdf due to the basis sorting in 923d84c (issue 1942)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Nov 9, 2023
1 parent 14d8882 commit 84916aa
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
7 changes: 2 additions & 5 deletions pyscf/gto/mole.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ def make_env(atoms, basis, pre_env=[], nucmod={}, nucprop={}):
'''
_atm = []
_bas = []
_env = []
_env = [pre_env]
ptr_env = len(pre_env)

for ia, atom in enumerate(atoms):
Expand Down Expand Up @@ -1092,10 +1092,7 @@ def make_env(atoms, basis, pre_env=[], nucmod={}, nucprop={}):
_bas = numpy.asarray(numpy.vstack(_bas), numpy.int32).reshape(-1, BAS_SLOTS)
else:
_bas = numpy.zeros((0,BAS_SLOTS), numpy.int32)
if _env:
_env = numpy.hstack((pre_env,numpy.hstack(_env)))
else:
_env = numpy.array(pre_env, copy=False)
_env = numpy.asarray(numpy.hstack(_env), dtype=numpy.float64)
return _atm, _bas, _env

def make_ecp_env(mol, _atm, ecp, pre_env=[]):
Expand Down
27 changes: 15 additions & 12 deletions pyscf/pbc/df/rsdf_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,21 @@
"""
class MoleNoBasSort(mol_gto.mole.Mole):
def build(self, **kwargs):
mol_gto.mole.Mole.build(self, **kwargs)

# sort self._bas
_bas = []
for iatm in range(self.natm):
atm = self.atom_symbol(iatm)
basin = self._basis[atm]
bas_ls = [b[0] for b in basin]
ls_uniq, ls_inv = np.unique(bas_ls, return_inverse=True)
order = np.argsort(np.concatenate([np.where(ls_inv==l)[0] for l in ls_uniq]))
_bas.append( self._bas[np.where(self._bas[:,0] == iatm)[0][order]] )
self._bas = np.vstack(_bas).astype(np.int32)
self.atom = kwargs.pop('atom')
self.basis = kwargs.pop('basis')
self._atom = mol_gto.format_atom(self.atom)
if isinstance(self.basis, dict):
self._basis = self.basis
else:
self._basis = {a[0]: self.basis for a in self._atom}

env = np.zeros(mol_gto.PTR_ENV_START)
# _bas should be constructed as it is in the input (see issue #1942).
self._atm, self._bas, self._env = self.make_env(
self._atom, self._basis, env)
self._built = True
return self

def _remove_exp_basis_(bold, amin, amax):
bnew = []
for b in bold:
Expand Down
24 changes: 23 additions & 1 deletion pyscf/pbc/df/test/test_rsdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pyscf import ao2mo
from pyscf.pbc import gto as pgto
from pyscf.pbc import scf as pscf
from pyscf.pbc.df import rsdf
from pyscf.pbc.df import rsdf, df
#from mpi4pyscf.pbc.df import df
pyscf.pbc.DEBUG = False

Expand Down Expand Up @@ -124,6 +124,28 @@ def test_get_eri_0123(self):
self.assertAlmostEqual(abs(eri0123.imag.sum()), 4.9887958509e-5, 7)
self.assertAlmostEqual(lib.fp(eri0123), 0.9695261296288074-0.33222740818370966j, 8)

def test_rsdf_build(self):
cell = pgto.M(a=numpy.eye(3)*1.8,
atom='''Li 0. 0. 0.; H 0. .5 1.2 ''',
basis={'Li': [[0, [5., 1.]], [0, [.6, 1.]], [1, [3., 1.]]],
'H': [[0, [.3, 1.]]]})
auxbasis = {'Li': [[0, [5., 1.]], [0, [1.5, 1.]], [1, [.5, 1.]], [2, [2.5, 1.]]],
'H': [[0, [2., 1.]]]}
numpy.random.seed(2)
dm = numpy.random.random([cell.nao]*2)

gdf = df.GDF(cell)
gdf.auxbasis = auxbasis
jref, kref = gdf.get_jk(dm)

gdf = df.RSGDF(cell)
gdf.auxbasis = auxbasis
vj, vk = gdf.get_jk(dm)

self.assertAlmostEqual(abs(vj - jref).max(), 0, 7)
self.assertAlmostEqual(abs(vk - kref).max(), 0, 7)
self.assertAlmostEqual(lib.fp(vj), 2.383648833459583, 7)
self.assertAlmostEqual(lib.fp(vk), 1.553598328349400, 7)

if __name__ == '__main__':
print("Full Tests for rsdf")
Expand Down

0 comments on commit 84916aa

Please sign in to comment.