Skip to content

Commit

Permalink
Add VHFOpt for LRDF
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Sep 29, 2023
1 parent a205750 commit ae1bc17
Show file tree
Hide file tree
Showing 8 changed files with 735 additions and 270 deletions.
416 changes: 404 additions & 12 deletions pyscf/lib/vhf/nr_sr_vhf.c

Large diffs are not rendered by default.

45 changes: 28 additions & 17 deletions pyscf/lib/vhf/optimizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ void CVHFsetnr_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt,
/*
* Non-relativistic 2-electron integrals
*/
void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
void CVHFnr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
{
int shls_slice[] = {0, nbas};
const int cache_size = GTOmax_cache_size(intor, shls_slice, 1,
Expand Down Expand Up @@ -448,6 +448,13 @@ void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond,
}
}

void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
{
CVHFnr_int2e_q_cond(intor, cintopt, q_cond, ao_loc, atm, natm, bas, nbas, env);
}

void CVHFset_q_cond(CVHFOpt *opt, double *q_cond, int len)
{
if (opt->q_cond != NULL) {
Expand All @@ -457,19 +464,10 @@ void CVHFset_q_cond(CVHFOpt *opt, double *q_cond, int len)
NPdcopy(opt->q_cond, q_cond, len);
}

void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc,
int *atm, int natm, int *bas, int nbas, double *env)
void CVHFnr_dm_cond(double *dm_cond, double *dm, int nset, int *ao_loc,
int *atm, int natm, int *bas, int nbas, double *env)
{
if (opt->dm_cond != NULL) { // NOT reuse opt->dm_cond because nset may be diff in different call
free(opt->dm_cond);
}
// nbas in the input arguments may different to opt->nbas.
// Use opt->nbas because it is used in the prescreen function
nbas = opt->nbas;
opt->dm_cond = (double *)malloc(sizeof(double) * nbas*nbas);
NPdset0(opt->dm_cond, ((size_t)nbas)*nbas);

const size_t nao = ao_loc[nbas];
size_t nao = ao_loc[nbas];
double dmax, tmp;
size_t i, j, ish, jsh, iset;
double *pdm;
Expand All @@ -487,11 +485,24 @@ void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc,
dmax = MAX(dmax, tmp);
} }
}
opt->dm_cond[ish*nbas+jsh] = .5 * dmax;
opt->dm_cond[jsh*nbas+ish] = .5 * dmax;
dm_cond[ish*nbas+jsh] = .5 * dmax;
dm_cond[jsh*nbas+ish] = .5 * dmax;
} }
}

void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc,
int *atm, int natm, int *bas, int nbas, double *env)
{
if (opt->dm_cond != NULL) { // NOT reuse opt->dm_cond because nset may be diff in different call
free(opt->dm_cond);
}
// nbas in the input arguments may different to opt->nbas.
// Use opt->nbas because it is used in the prescreen function
nbas = opt->nbas;
opt->dm_cond = (double *)malloc(sizeof(double) * nbas*nbas);
CVHFnr_dm_cond(opt->dm_cond, dm, nset, ao_loc, atm, natm, bas, nbas, env);
}

void CVHFset_dm_cond(CVHFOpt *opt, double *dm_cond, int len)
{
if (opt->dm_cond != NULL) {
Expand Down
24 changes: 19 additions & 5 deletions pyscf/lrdf/grad/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# limitations under the License.
#

import numpy as np
from pyscf import lib
from pyscf.scf import _vhf
from pyscf.lib import logger
from pyscf.lrdf import lrdf
from pyscf.grad import rhf as rhf_grad
Expand All @@ -35,17 +37,29 @@ def get_jk(self, mol=None, dm=None, hermi=0, omega=None):

lrdf_obj = self.base.with_df
omega = lrdf_obj.omega
vj, vk = rhf_grad.Gradients.get_jk(self, mol, dm, hermi, -omega)
with mol.with_range_coulomb(-omega):
vj, vk = rhf_grad.get_jk(mol, dm)
# TODO: initialize q_cond with CVHFgrad_jk_direct_scf
#vhfopt = lrdf._VHFOpt(mol, 'int2e_ip1',
# prescreen='CVHFgrad_jk_prescreen', omega=omega)
vhfopt = lrdf._VHFOpt(mol, 'int2e_ip1', omega=omega)
vhfopt._this.q_cond = lrdf_obj._vhfopt._this.q_cond
vhfopt._this.dm_cond = lrdf_obj._vhfopt._this.dm_cond

with mol.with_short_range_coulomb(omega):
intor = mol._add_suffix('int2e_ip1')
vj, vk = _vhf.direct_mapdm(intor, # (nabla i,j|k,l)
's2kl', # ip1_sph has k>=l,
('lk->s1ij', 'jk->s1il'),
dm, 3, # xyz, 3 components
mol._atm, mol._bas, mol._env, vhfopt=vhfopt,
optimize_sr=True)

with lrdf_obj.range_coulomb(omega):
with lib.temporary_env(lrdf_obj, auxmol=lrdf_obj.lr_auxmol):
vj1, vk1 = df_rhf_grad.get_jk(self, mol, dm, hermi,
decompose_j2c='ED',
lindep=lrdf_obj.lr_thresh)
vj += vj1
vk += vk1
vj = vj1 - np.asarray(vj)
vk = vk1 - np.asarray(vk)
if self.auxbasis_response:
vj = lib.tag_array(vj, aux=vj1.aux)
vk = lib.tag_array(vk, aux=vk1.aux)
Expand Down
154 changes: 63 additions & 91 deletions pyscf/lrdf/lrdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,23 @@
import ctypes
import tempfile
import numpy as np
import scipy.special
from pyscf import gto
from pyscf import lib
from pyscf.lib import logger
from pyscf.df import df, df_jk, outcore, addons
from pyscf.gto import ft_ao
from pyscf.dft.gen_grid import LEBEDEV_NGRID, libdft
from pyscf.gto.moleintor import make_cintopt
from pyscf.pbc.df.incore import libpbc
from pyscf.scf._vhf import libcvhf, _fpointer
from pyscf.scf._vhf import libcvhf
from pyscf.scf import _vhf

MIN_CUTOFF = 1e-44
AUXBASIS = {
#'H': [[0, [1., 1]]],
'default': [[0, [1., 1]], [1, [1., 1]], [2, [1., 1]]]
}

class _CVHFOpt(ctypes.Structure):
_fields_ = [('nbas', ctypes.c_int),
('ngrids', ctypes.c_int),
('log_cutoff', ctypes.c_double),
('logq_cond', ctypes.c_void_p),
('dm_cond', ctypes.c_void_p),
('fprescreen', ctypes.c_void_p),
('r_vkscreen', ctypes.c_void_p)]

class LRDensityFitting(df.DF):

omega = 0.1
Expand All @@ -42,18 +34,15 @@ class LRDensityFitting(df.DF):
lr_dfj = True

def __init__(self, mol, auxbasis=None):
self._intor = 'int2e'
self._cintopt = None
self.q_cond = None
self.lr_auxmol = None
self.wcoulG = None
self.Gv = None
self._vhfopt = None
self._last_vs = (0, 0, 0)
df.DF.__init__(self, mol, auxbasis)

def reset(self, mol=None):
self.q_cond = None
self._cintopt = None
self._vhfopt = None
return df.DF.reset(self, mol)

def dump_flags(self, verbose=None):
Expand Down Expand Up @@ -81,24 +70,12 @@ def build(self):
self.dump_flags()

mol = self.mol
nbas = mol.nbas
self.q_cond = np.empty((6,nbas,nbas), dtype=np.float32)
ao_loc = mol.ao_loc
omega = self.omega
assert omega > 0

with mol.with_short_range_coulomb(omega):
self._cintopt = make_cintopt(
mol._atm, mol._bas, mol._env, self._intor)

with mol.with_integral_screen(self.direct_scf_tol**2):
libpbc.CVHFsetnr_sr_direct_scf(
libpbc.int2e_sph, self._cintopt,
self.q_cond.ctypes.data_as(ctypes.c_void_p),
ao_loc.ctypes.data_as(ctypes.c_void_p),
mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm),
mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas),
mol._env.ctypes.data_as(ctypes.c_void_p))
self._vhfopt = vhfopt = _VHFOpt(mol, 'int2e', omega=omega)
with mol.with_integral_screen(self.direct_scf_tol**2):
vhfopt.init_cvhf_direct(mol)
cpu0 = log.timer('initializing q_cond', *cpu0)

if self.lr_auxmol is None:
Expand Down Expand Up @@ -165,7 +142,7 @@ def get_jk(self, dm, hermi=1, with_j=True, with_k=True,
return vj, vk

def _get_jk_sr(self, dm, hermi=1, with_j=True, with_k=True):
if self.q_cond is None:
if self._vhfopt is None:
self.build()

assert hermi == 1
Expand All @@ -174,61 +151,21 @@ def _get_jk_sr(self, dm, hermi=1, with_j=True, with_k=True):
mol = self.mol
n_dm, nao = dm.shape[:2]

dm_cond = _make_dm_cond(mol, dm, self.direct_scf_tol)
vhfopt = _CVHFOpt()
vhfopt.dm_cond = dm_cond.ctypes.data_as(ctypes.c_void_p)
vhfopt.logq_cond = self.q_cond.ctypes.data_as(ctypes.c_void_p)
vhfopt.log_cutoff = np.log(self.direct_scf_tol)

intor = mol._add_suffix(self._intor)
cintor = getattr(libcvhf, intor)
fdot = getattr(libcvhf, 'CVHFdot_nr_sr_s8')

vj = vk = None
dmsptr = []
vjkptr = []
fjk = []

if with_j:
fvj = _fpointer('CVHFnrs8_ji_s2kl')
vj = np.empty((n_dm,nao,nao))
for i in range(n_dm):
dmsptr.append(dm[i].ctypes.data_as(ctypes.c_void_p))
vjkptr.append(vj[i].ctypes.data_as(ctypes.c_void_p))
fjk.append(fvj)

if with_k:
fvk = _fpointer('CVHFnrs8_li_s2kj')
vk = np.empty((n_dm,nao,nao))
for i in range(n_dm):
dmsptr.append(dm[i].ctypes.data_as(ctypes.c_void_p))
vjkptr.append(vk[i].ctypes.data_as(ctypes.c_void_p))
fjk.append(fvk)

shls_slice = (ctypes.c_int*8)(*([0, mol.nbas]*4))
ao_loc = mol.ao_loc
n_ops = len(dmsptr)
comp = 1
if with_j and with_k:
out = np.empty((2*n_dm, nao, nao))
vj = out[:n_dm]
vk = out[n_dm:]
elif with_k:
vj = out = np.empty((n_dm, nao, nao))
elif with_k:
vk = out = np.empty((n_dm, nao, nao))
else:
return vj, vk

with mol.with_short_range_coulomb(self.omega):
libcvhf.CVHFnr_sr_direct_drv(
cintor, fdot, (ctypes.c_void_p*n_ops)(*fjk),
(ctypes.c_void_p*n_ops)(*dmsptr),
(ctypes.c_void_p*n_ops)(*vjkptr),
ctypes.c_int(n_ops), ctypes.c_int(comp),
shls_slice, ao_loc.ctypes.data_as(ctypes.c_void_p),
self._cintopt, ctypes.byref(vhfopt),
mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm),
mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas),
mol._env.ctypes.data_as(ctypes.c_void_p))

if with_j:
for i in range(n_dm):
lib.hermi_triu(vj[i], 1, inplace=True)
if with_k:
if hermi != 0:
for i in range(n_dm):
lib.hermi_triu(vk[i], hermi, inplace=True)
_vhf.direct(dm, mol._atm, mol._bas, mol._env, self._vhfopt, hermi,
mol.cart, with_j, with_k, out, optimize_sr=True)
logger.timer(mol, 'short range part vj and vk', *cpu0)
return vj, vk

Expand Down Expand Up @@ -294,13 +231,48 @@ def _get_jk_lr(self, dm, hermi=1, with_j=True, with_k=True):

LRDF = LRDensityFitting

def _make_dm_cond(mol, dm, direct_scf_tol):
assert dm.ndim == 3
ao_loc = mol.ao_loc
dm_cond = [lib.condense('NP_absmax', d, ao_loc, ao_loc) for d in dm]
dm_cond = np.max(dm_cond, axis=0)
dm_cond += MIN_CUTOFF # to remove divide-by-zero error
return np.asarray(dm_cond, order='C', dtype=np.float32)
class _VHFOpt(_vhf._VHFOpt):
def __init__(self, mol, intor=None, prescreen='CVHFnoscreen',
qcondname=None, dmcondname=None, omega=None):
assert omega is not None
with mol.with_short_range_coulomb(omega):
_vhf._VHFOpt.__init__(self, mol, intor, prescreen, qcondname, dmcondname)
self.omega = omega
self._this.direct_scf_cutoff = np.log(1e-14)

@property
def direct_scf_tol(self):
return np.exp(self._this.direct_scf_cutoff)
@direct_scf_tol.setter
def direct_scf_tol(self, v):
self._this.direct_scf_cutoff = np.log(v)

def init_cvhf_direct(self, mol, intor=None, qcondname=None):
nbas = mol.nbas
q_cond = np.empty((6,nbas,nbas), dtype=np.float32)
ao_loc = mol.ao_loc
cintopt = self._cintopt
with mol.with_short_range_coulomb(self.omega):
libcvhf.CVHFsetnr_sr_direct_scf(
libcvhf.int2e_sph, cintopt,
q_cond.ctypes.data_as(ctypes.c_void_p),
ao_loc.ctypes.data_as(ctypes.c_void_p),
mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm),
mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas),
mol._env.ctypes.data_as(ctypes.c_void_p))

self._q_cond = q_cond
logq_cond = q_cond.ctypes.data_as(ctypes.c_void_p)
self._this.q_cond = logq_cond

def set_dm(self, dm, atm=None, bas=None, env=None):
assert dm[0].ndim == 2
ao_loc = self.mol.ao_loc_nr()
dm_cond = [lib.condense('NP_absmax', d, ao_loc, ao_loc) for d in dm]
dm_cond = np.max(dm_cond, axis=0)
dm_cond += MIN_CUTOFF # to remove divide-by-zero error
self._dm_cond = np.asarray(dm_cond, order='C', dtype=np.float32)
self._this.dm_cond = self._dm_cond.ctypes.data_as(ctypes.c_void_p)

def _quadrature_roots(n, omega):
rs, ws = scipy.special.roots_hermite(n*2)
Expand Down
2 changes: 1 addition & 1 deletion pyscf/lrdf/test/test_df_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ def test_rhf_grad(self):
self.assertAlmostEqual(abs(g1-ref).max(), 0, 5)

if __name__ == "__main__":
print("Full Tests for df.grad")
print("Full Tests for lrdf.grad")
unittest.main()
2 changes: 1 addition & 1 deletion pyscf/lrdf/test/test_lrdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyscf import lib
from pyscf import gto
from pyscf.scf import hf
from pyscf.df import lrdf
from pyscf.lrdf import lrdf


class KnownValues(unittest.TestCase):
Expand Down
Loading

0 comments on commit ae1bc17

Please sign in to comment.