diff --git a/pyscf/lib/vhf/nr_sr_vhf.c b/pyscf/lib/vhf/nr_sr_vhf.c index 0eb5703f84..742b768a37 100644 --- a/pyscf/lib/vhf/nr_sr_vhf.c +++ b/pyscf/lib/vhf/nr_sr_vhf.c @@ -19,10 +19,402 @@ float NP_fmax(float *a, int nd, int di, int dj); int CVHFshls_block_partition(int *block_loc, int *shls_slice, int *ao_loc, int block_size); -void CVHFdot_nr_sr_s4(int (*intor)(), JKOperator **jkop, JKArray **vjk, - double **dms, double *buf, double *cache, int n_dm, - int *ishls, int *jshls, int *kshls, int *lshls, - CVHFOpt *vhfopt, IntorEnvs *envs) +void CVHFdot_sr_nrs1(int (*intor)(), JKOperator **jkop, JKArray **vjk, + double **dms, double *buf, double *cache, int n_dm, + int *ishls, int *jshls, int *kshls, int *lshls, + CVHFOpt *vhfopt, IntorEnvs *envs) +{ + int *atm = envs->atm; + int *bas = envs->bas; + double *env = envs->env; + int natm = envs->natm; + int nbas = envs->nbas; + int *ao_loc = envs->ao_loc; + CINTOpt *cintopt = envs->cintopt; + int ish0 = ishls[0]; + int ish1 = ishls[1]; + int jsh0 = jshls[0]; + int jsh1 = jshls[1]; + int ksh0 = kshls[0]; + int ksh1 = kshls[1]; + int lsh0 = lshls[0]; + int lsh1 = lshls[1]; + size_t Nbas = nbas; + size_t Nbas2 = Nbas * Nbas; + float *q_ijij = (float *)vhfopt->logq_cond; + float *q_iijj = q_ijij + Nbas2; + float *s_index = q_iijj + Nbas2; + float *xij_cond = s_index + Nbas2; + float *yij_cond = xij_cond + Nbas2; + float *zij_cond = yij_cond + Nbas2; + float *dm_cond = (float *)vhfopt->dm_cond; + float kl_cutoff, jl_cutoff, il_cutoff; + float log_cutoff = vhfopt->log_cutoff; + float omega = env[PTR_RANGE_OMEGA]; + float omega2 = omega * omega; + float dm_max0, dm_max, log_dm; + float theta, theta_ij, theta_r2, skl_cutoff; + float xij, yij, zij, xkl, ykl, zkl, dx, dy, dz, r2; + int shls[4]; + void (*pf)(double *eri, double *dm, JKArray *vjk, int *shls, + int i0, int i1, int j0, int j1, + int k0, int k1, int l0, int l1); + int notempty; + int ish, jsh, ksh, lsh, i0, j0, k0, l0, i1, j1, k1, l1, idm; + double ai, aj, ak, al, aij, akl; + + for (ish = ish0; ish < ish1; ish++) { + shls[0] = ish; + ai = env[bas(PTR_EXP,ish) + bas(NPRIM_OF,ish)-1]; + + for (jsh = jsh0; jsh < jsh1; jsh++) { + if (q_ijij[ish*Nbas+jsh] < log_cutoff) { + continue; + } + shls[1] = jsh; + aj = env[bas(PTR_EXP,jsh) + bas(NPRIM_OF,jsh)-1]; + aij = ai + aj; + theta_ij = omega2*aij / (omega2 + aij); + kl_cutoff = log_cutoff - q_ijij[ish*Nbas+jsh]; + xij = xij_cond[ish * Nbas + jsh]; + yij = yij_cond[ish * Nbas + jsh]; + zij = zij_cond[ish * Nbas + jsh]; + skl_cutoff = log_cutoff - s_index[ish * Nbas + jsh]; + +for (ksh = ksh0; ksh < ksh1; ksh++) { + if (q_iijj[ish*Nbas+ksh] < log_cutoff || + q_iijj[jsh*Nbas+ksh] < log_cutoff) { + continue; + } + shls[2] = ksh; + ak = env[bas(PTR_EXP,ksh) + bas(NPRIM_OF,ksh)-1]; + jl_cutoff = log_cutoff - q_iijj[ish*Nbas+ksh]; + il_cutoff = log_cutoff - q_iijj[jsh*Nbas+ksh]; + + dm_max0 = dm_cond[ish*nbas+jsh]; + dm_max0 = MAX(dm_max0, dm_cond[ish*nbas+ksh]); + dm_max0 = MAX(dm_max0, dm_cond[jsh*nbas+ksh]); + + for (lsh = lsh0; lsh < lsh1; lsh++) { + dm_max = dm_max0 + dm_cond[ksh*nbas+lsh] + + dm_cond[ish*nbas+lsh] + dm_cond[jsh*nbas+lsh]; + log_dm = logf(dm_max); + if (q_ijij[ksh*Nbas+lsh] + log_dm < kl_cutoff || + q_iijj[jsh*Nbas+lsh] + log_dm < jl_cutoff || + q_iijj[ish*Nbas+lsh] + log_dm < il_cutoff) { + continue; + } + + al = env[bas(PTR_EXP,lsh) + bas(NPRIM_OF,lsh)-1]; + akl = ak + al; + // theta = 1/(1/aij+1/akl+1/omega2); + theta = theta_ij*akl / (theta_ij + akl); + + xkl = xij_cond[ksh * Nbas + lsh]; + ykl = yij_cond[ksh * Nbas + lsh]; + zkl = zij_cond[ksh * Nbas + lsh]; + dx = xij - xkl; + dy = yij - ykl; + dz = zij - zkl; + r2 = dx * dx + dy * dy + dz * dz; + theta_r2 = logf(r2 + 1e-30f) + theta * r2 - log_dm; + if (theta_r2 + skl_cutoff > s_index[ksh*Nbas+lsh]) { + continue; + } + shls[3] = lsh; + notempty = (*intor)(buf, NULL, shls, + atm, natm, bas, nbas, env, cintopt, cache); + if (notempty) { + i0 = ao_loc[ish]; + j0 = ao_loc[jsh]; + k0 = ao_loc[ksh]; + l0 = ao_loc[lsh]; + i1 = ao_loc[ish+1]; + j1 = ao_loc[jsh+1]; + k1 = ao_loc[ksh+1]; + l1 = ao_loc[lsh+1]; + for (idm = 0; idm < n_dm; idm++) { + pf = jkop[idm]->contract; + (*pf)(buf, dms[idm], vjk[idm], shls, + i0, i1, j0, j1, k0, k1, l0, l1); + } + } + } +} + } + } +} + +void CVHFdot_sr_nrs2ij(int (*intor)(), JKOperator **jkop, JKArray **vjk, + double **dms, double *buf, double *cache, int n_dm, + int *ishls, int *jshls, int *kshls, int *lshls, + CVHFOpt *vhfopt, IntorEnvs *envs) +{ + if (ishls[0] > jshls[0]) { + return CVHFdot_sr_nrs1(intor, jkop, vjk, dms, buf, cache, n_dm, + ishls, jshls, kshls, lshls, vhfopt, envs); + } else if (ishls[0] < jshls[0]) { + return; + } + + int *atm = envs->atm; + int *bas = envs->bas; + double *env = envs->env; + int natm = envs->natm; + int nbas = envs->nbas; + int *ao_loc = envs->ao_loc; + CINTOpt *cintopt = envs->cintopt; + int ish0 = ishls[0]; + int ish1 = ishls[1]; + int jsh0 = jshls[0]; + int jsh1 = jshls[1]; + int ksh0 = kshls[0]; + int ksh1 = kshls[1]; + int lsh0 = lshls[0]; + int lsh1 = lshls[1]; + size_t Nbas = nbas; + size_t Nbas2 = Nbas * Nbas; + float *q_ijij = (float *)vhfopt->logq_cond; + float *q_iijj = q_ijij + Nbas2; + float *s_index = q_iijj + Nbas2; + float *xij_cond = s_index + Nbas2; + float *yij_cond = xij_cond + Nbas2; + float *zij_cond = yij_cond + Nbas2; + float *dm_cond = (float *)vhfopt->dm_cond; + float kl_cutoff, jl_cutoff, il_cutoff; + float log_cutoff = vhfopt->log_cutoff; + float omega = env[PTR_RANGE_OMEGA]; + float omega2 = omega * omega; + float dm_max0, dm_max, log_dm; + float theta, theta_ij, theta_r2, skl_cutoff; + float xij, yij, zij, xkl, ykl, zkl, dx, dy, dz, r2; + int shls[4]; + void (*pf)(double *eri, double *dm, JKArray *vjk, int *shls, + int i0, int i1, int j0, int j1, + int k0, int k1, int l0, int l1); + int notempty; + int ish, jsh, ksh, lsh, i0, j0, k0, l0, i1, j1, k1, l1, idm; + double ai, aj, ak, al, aij, akl; + + for (ish = ish0; ish < ish1; ish++) { + shls[0] = ish; + ai = env[bas(PTR_EXP,ish) + bas(NPRIM_OF,ish)-1]; + + for (jsh = jsh0; jsh <= ish; jsh++) { + if (q_ijij[ish*Nbas+jsh] < log_cutoff) { + continue; + } + shls[1] = jsh; + aj = env[bas(PTR_EXP,jsh) + bas(NPRIM_OF,jsh)-1]; + aij = ai + aj; + theta_ij = omega2*aij / (omega2 + aij); + kl_cutoff = log_cutoff - q_ijij[ish*Nbas+jsh]; + xij = xij_cond[ish * Nbas + jsh]; + yij = yij_cond[ish * Nbas + jsh]; + zij = zij_cond[ish * Nbas + jsh]; + skl_cutoff = log_cutoff - s_index[ish * Nbas + jsh]; + +for (ksh = ksh0; ksh < ksh1; ksh++) { + if (q_iijj[ish*Nbas+ksh] < log_cutoff || + q_iijj[jsh*Nbas+ksh] < log_cutoff) { + continue; + } + shls[2] = ksh; + ak = env[bas(PTR_EXP,ksh) + bas(NPRIM_OF,ksh)-1]; + jl_cutoff = log_cutoff - q_iijj[ish*Nbas+ksh]; + il_cutoff = log_cutoff - q_iijj[jsh*Nbas+ksh]; + + dm_max0 = dm_cond[ish*nbas+jsh]; + dm_max0 = MAX(dm_max0, dm_cond[ish*nbas+ksh]); + dm_max0 = MAX(dm_max0, dm_cond[jsh*nbas+ksh]); + + for (lsh = lsh0; lsh < lsh1; lsh++) { + dm_max = dm_max0 + dm_cond[ksh*nbas+lsh] + + dm_cond[ish*nbas+lsh] + dm_cond[jsh*nbas+lsh]; + log_dm = logf(dm_max); + if (q_ijij[ksh*Nbas+lsh] + log_dm < kl_cutoff || + q_iijj[jsh*Nbas+lsh] + log_dm < jl_cutoff || + q_iijj[ish*Nbas+lsh] + log_dm < il_cutoff) { + continue; + } + + al = env[bas(PTR_EXP,lsh) + bas(NPRIM_OF,lsh)-1]; + akl = ak + al; + // theta = 1/(1/aij+1/akl+1/omega2); + theta = theta_ij*akl / (theta_ij + akl); + + xkl = xij_cond[ksh * Nbas + lsh]; + ykl = yij_cond[ksh * Nbas + lsh]; + zkl = zij_cond[ksh * Nbas + lsh]; + dx = xij - xkl; + dy = yij - ykl; + dz = zij - zkl; + r2 = dx * dx + dy * dy + dz * dz; + theta_r2 = logf(r2 + 1e-30f) + theta * r2 - log_dm; + if (theta_r2 + skl_cutoff > s_index[ksh*Nbas+lsh]) { + continue; + } + shls[3] = lsh; + notempty = (*intor)(buf, NULL, shls, + atm, natm, bas, nbas, env, cintopt, cache); + if (notempty) { + i0 = ao_loc[ish]; + j0 = ao_loc[jsh]; + k0 = ao_loc[ksh]; + l0 = ao_loc[lsh]; + i1 = ao_loc[ish+1]; + j1 = ao_loc[jsh+1]; + k1 = ao_loc[ksh+1]; + l1 = ao_loc[lsh+1]; + for (idm = 0; idm < n_dm; idm++) { + pf = jkop[idm]->contract; + (*pf)(buf, dms[idm], vjk[idm], shls, + i0, i1, j0, j1, k0, k1, l0, l1); + } + } + } +} + } + } +} + +void CVHFdot_sr_nrs2kl(int (*intor)(), JKOperator **jkop, JKArray **vjk, + double **dms, double *buf, double *cache, int n_dm, + int *ishls, int *jshls, int *kshls, int *lshls, + CVHFOpt *vhfopt, IntorEnvs *envs) +{ + if (kshls[0] > lshls[0]) { + return CVHFdot_sr_nrs1(intor, jkop, vjk, dms, buf, cache, n_dm, + ishls, jshls, kshls, lshls, vhfopt, envs); + } else if (kshls[0] < lshls[0]) { + return; + } + + int *atm = envs->atm; + int *bas = envs->bas; + double *env = envs->env; + int natm = envs->natm; + int nbas = envs->nbas; + int *ao_loc = envs->ao_loc; + CINTOpt *cintopt = envs->cintopt; + int ish0 = ishls[0]; + int ish1 = ishls[1]; + int jsh0 = jshls[0]; + int jsh1 = jshls[1]; + int ksh0 = kshls[0]; + int ksh1 = kshls[1]; + int lsh0 = lshls[0]; + int lsh1 = lshls[1]; + size_t Nbas = nbas; + size_t Nbas2 = Nbas * Nbas; + float *q_ijij = (float *)vhfopt->logq_cond; + float *q_iijj = q_ijij + Nbas2; + float *s_index = q_iijj + Nbas2; + float *xij_cond = s_index + Nbas2; + float *yij_cond = xij_cond + Nbas2; + float *zij_cond = yij_cond + Nbas2; + float *dm_cond = (float *)vhfopt->dm_cond; + float kl_cutoff, jl_cutoff, il_cutoff; + float log_cutoff = vhfopt->log_cutoff; + float omega = env[PTR_RANGE_OMEGA]; + float omega2 = omega * omega; + float dm_max0, dm_max, log_dm; + float theta, theta_ij, theta_r2, skl_cutoff; + float xij, yij, zij, xkl, ykl, zkl, dx, dy, dz, r2; + int shls[4]; + void (*pf)(double *eri, double *dm, JKArray *vjk, int *shls, + int i0, int i1, int j0, int j1, + int k0, int k1, int l0, int l1); + int notempty; + int ish, jsh, ksh, lsh, i0, j0, k0, l0, i1, j1, k1, l1, idm; + double ai, aj, ak, al, aij, akl; + + for (ish = ish0; ish < ish1; ish++) { + shls[0] = ish; + ai = env[bas(PTR_EXP,ish) + bas(NPRIM_OF,ish)-1]; + + for (jsh = jsh0; jsh < jsh1; jsh++) { + if (q_ijij[ish*Nbas+jsh] < log_cutoff) { + continue; + } + shls[1] = jsh; + aj = env[bas(PTR_EXP,jsh) + bas(NPRIM_OF,jsh)-1]; + aij = ai + aj; + theta_ij = omega2*aij / (omega2 + aij); + kl_cutoff = log_cutoff - q_ijij[ish*Nbas+jsh]; + xij = xij_cond[ish * Nbas + jsh]; + yij = yij_cond[ish * Nbas + jsh]; + zij = zij_cond[ish * Nbas + jsh]; + skl_cutoff = log_cutoff - s_index[ish * Nbas + jsh]; + +for (ksh = ksh0; ksh < ksh1; ksh++) { + if (q_iijj[ish*Nbas+ksh] < log_cutoff || + q_iijj[jsh*Nbas+ksh] < log_cutoff) { + continue; + } + shls[2] = ksh; + ak = env[bas(PTR_EXP,ksh) + bas(NPRIM_OF,ksh)-1]; + jl_cutoff = log_cutoff - q_iijj[ish*Nbas+ksh]; + il_cutoff = log_cutoff - q_iijj[jsh*Nbas+ksh]; + + dm_max0 = dm_cond[ish*nbas+jsh]; + dm_max0 = MAX(dm_max0, dm_cond[ish*nbas+ksh]); + dm_max0 = MAX(dm_max0, dm_cond[jsh*nbas+ksh]); + + for (lsh = lsh0; lsh <= ksh; lsh++) { + dm_max = dm_max0 + dm_cond[ksh*nbas+lsh] + + dm_cond[ish*nbas+lsh] + dm_cond[jsh*nbas+lsh]; + log_dm = logf(dm_max); + if (q_ijij[ksh*Nbas+lsh] + log_dm < kl_cutoff || + q_iijj[jsh*Nbas+lsh] + log_dm < jl_cutoff || + q_iijj[ish*Nbas+lsh] + log_dm < il_cutoff) { + continue; + } + + al = env[bas(PTR_EXP,lsh) + bas(NPRIM_OF,lsh)-1]; + akl = ak + al; + // theta = 1/(1/aij+1/akl+1/omega2); + theta = theta_ij*akl / (theta_ij + akl); + + xkl = xij_cond[ksh * Nbas + lsh]; + ykl = yij_cond[ksh * Nbas + lsh]; + zkl = zij_cond[ksh * Nbas + lsh]; + dx = xij - xkl; + dy = yij - ykl; + dz = zij - zkl; + r2 = dx * dx + dy * dy + dz * dz; + theta_r2 = logf(r2 + 1e-30f) + theta * r2 - log_dm; + if (theta_r2 + skl_cutoff > s_index[ksh*Nbas+lsh]) { + continue; + } + shls[3] = lsh; + notempty = (*intor)(buf, NULL, shls, + atm, natm, bas, nbas, env, cintopt, cache); + if (notempty) { + i0 = ao_loc[ish]; + j0 = ao_loc[jsh]; + k0 = ao_loc[ksh]; + l0 = ao_loc[lsh]; + i1 = ao_loc[ish+1]; + j1 = ao_loc[jsh+1]; + k1 = ao_loc[ksh+1]; + l1 = ao_loc[lsh+1]; + for (idm = 0; idm < n_dm; idm++) { + pf = jkop[idm]->contract; + (*pf)(buf, dms[idm], vjk[idm], shls, + i0, i1, j0, j1, k0, k1, l0, l1); + } + } + } +} + } + } +} + +void CVHFdot_sr_nrs4(int (*intor)(), JKOperator **jkop, JKArray **vjk, + double **dms, double *buf, double *cache, int n_dm, + int *ishls, int *jshls, int *kshls, int *lshls, + CVHFOpt *vhfopt, IntorEnvs *envs) { if (ishls[0] < jshls[0] || kshls[0] < lshls[0]) { return; @@ -71,7 +463,7 @@ void CVHFdot_nr_sr_s4(int (*intor)(), JKOperator **jkop, JKArray **vjk, shls[0] = ish; ai = env[bas(PTR_EXP,ish) + bas(NPRIM_OF,ish)-1]; - for (jsh = jsh0; jsh < MIN(ish+1,jsh1); jsh++) { + for (jsh = jsh0; jsh < MIN(jsh1, ish+1); jsh++) { if (q_ijij[ish*Nbas+jsh] < log_cutoff) { continue; } @@ -99,7 +491,7 @@ for (ksh = ksh0; ksh < ksh1; ksh++) { dm_max0 = MAX(dm_max0, dm_cond[ish*nbas+ksh]); dm_max0 = MAX(dm_max0, dm_cond[jsh*nbas+ksh]); - for (lsh = lsh0; lsh < MIN(ksh+1,lsh1); lsh++) { + for (lsh = lsh0; lsh < MIN(lsh1, ksh+1); lsh++) { dm_max = dm_max0 + dm_cond[ksh*nbas+lsh] + dm_cond[ish*nbas+lsh] + dm_cond[jsh*nbas+lsh]; log_dm = logf(dm_max); @@ -149,14 +541,14 @@ for (ksh = ksh0; ksh < ksh1; ksh++) { } } -void CVHFdot_nr_sr_s8(int (*intor)(), JKOperator **jkop, JKArray **vjk, - double **dms, double *buf, double *cache, int n_dm, - int *ishls, int *jshls, int *kshls, int *lshls, - CVHFOpt *vhfopt, IntorEnvs *envs) +void CVHFdot_sr_nrs8(int (*intor)(), JKOperator **jkop, JKArray **vjk, + double **dms, double *buf, double *cache, int n_dm, + int *ishls, int *jshls, int *kshls, int *lshls, + CVHFOpt *vhfopt, IntorEnvs *envs) { if (ishls[0] > kshls[0]) { - return CVHFdot_nr_sr_s4(intor, jkop, vjk, dms, buf, cache, n_dm, - ishls, jshls, kshls, lshls, vhfopt, envs); + return CVHFdot_sr_nrs4(intor, jkop, vjk, dms, buf, cache, n_dm, + ishls, jshls, kshls, lshls, vhfopt, envs); } else if (ishls[0] < kshls[0]) { return; } else if ((ishls[1] <= jshls[0]) || (kshls[1] <= lshls[0])) { diff --git a/pyscf/lib/vhf/optimizer.c b/pyscf/lib/vhf/optimizer.c index 996bf4b6ba..7833da36f8 100644 --- a/pyscf/lib/vhf/optimizer.c +++ b/pyscf/lib/vhf/optimizer.c @@ -400,9 +400,9 @@ void CVHFsetnr_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, /* * Non-relativistic 2-electron integrals */ -void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void CVHFnr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { int shls_slice[] = {0, nbas}; const int cache_size = GTOmax_cache_size(intor, shls_slice, 1, @@ -448,6 +448,13 @@ void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, } } +void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + CVHFnr_int2e_q_cond(intor, cintopt, q_cond, ao_loc, atm, natm, bas, nbas, env); +} + void CVHFset_q_cond(CVHFOpt *opt, double *q_cond, int len) { if (opt->q_cond != NULL) { @@ -457,19 +464,10 @@ void CVHFset_q_cond(CVHFOpt *opt, double *q_cond, int len) NPdcopy(opt->q_cond, q_cond, len); } -void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, - int *atm, int natm, int *bas, int nbas, double *env) +void CVHFnr_dm_cond(double *dm_cond, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env) { - if (opt->dm_cond != NULL) { // NOT reuse opt->dm_cond because nset may be diff in different call - free(opt->dm_cond); - } - // nbas in the input arguments may different to opt->nbas. - // Use opt->nbas because it is used in the prescreen function - nbas = opt->nbas; - opt->dm_cond = (double *)malloc(sizeof(double) * nbas*nbas); - NPdset0(opt->dm_cond, ((size_t)nbas)*nbas); - - const size_t nao = ao_loc[nbas]; + size_t nao = ao_loc[nbas]; double dmax, tmp; size_t i, j, ish, jsh, iset; double *pdm; @@ -487,11 +485,24 @@ void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, dmax = MAX(dmax, tmp); } } } - opt->dm_cond[ish*nbas+jsh] = .5 * dmax; - opt->dm_cond[jsh*nbas+ish] = .5 * dmax; + dm_cond[ish*nbas+jsh] = .5 * dmax; + dm_cond[jsh*nbas+ish] = .5 * dmax; } } } +void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env) +{ + if (opt->dm_cond != NULL) { // NOT reuse opt->dm_cond because nset may be diff in different call + free(opt->dm_cond); + } + // nbas in the input arguments may different to opt->nbas. + // Use opt->nbas because it is used in the prescreen function + nbas = opt->nbas; + opt->dm_cond = (double *)malloc(sizeof(double) * nbas*nbas); + CVHFnr_dm_cond(opt->dm_cond, dm, nset, ao_loc, atm, natm, bas, nbas, env); +} + void CVHFset_dm_cond(CVHFOpt *opt, double *dm_cond, int len) { if (opt->dm_cond != NULL) { diff --git a/pyscf/lrdf/grad/rhf.py b/pyscf/lrdf/grad/rhf.py index b2571274ad..4b67da7100 100644 --- a/pyscf/lrdf/grad/rhf.py +++ b/pyscf/lrdf/grad/rhf.py @@ -14,7 +14,9 @@ # limitations under the License. # +import numpy as np from pyscf import lib +from pyscf.scf import _vhf from pyscf.lib import logger from pyscf.lrdf import lrdf from pyscf.grad import rhf as rhf_grad @@ -35,17 +37,29 @@ def get_jk(self, mol=None, dm=None, hermi=0, omega=None): lrdf_obj = self.base.with_df omega = lrdf_obj.omega - vj, vk = rhf_grad.Gradients.get_jk(self, mol, dm, hermi, -omega) - with mol.with_range_coulomb(-omega): - vj, vk = rhf_grad.get_jk(mol, dm) + # TODO: initialize q_cond with CVHFgrad_jk_direct_scf + #vhfopt = lrdf._VHFOpt(mol, 'int2e_ip1', + # prescreen='CVHFgrad_jk_prescreen', omega=omega) + vhfopt = lrdf._VHFOpt(mol, 'int2e_ip1', omega=omega) + vhfopt._this.q_cond = lrdf_obj._vhfopt._this.q_cond + vhfopt._this.dm_cond = lrdf_obj._vhfopt._this.dm_cond + + with mol.with_short_range_coulomb(omega): + intor = mol._add_suffix('int2e_ip1') + vj, vk = _vhf.direct_mapdm(intor, # (nabla i,j|k,l) + 's2kl', # ip1_sph has k>=l, + ('lk->s1ij', 'jk->s1il'), + dm, 3, # xyz, 3 components + mol._atm, mol._bas, mol._env, vhfopt=vhfopt, + optimize_sr=True) with lrdf_obj.range_coulomb(omega): with lib.temporary_env(lrdf_obj, auxmol=lrdf_obj.lr_auxmol): vj1, vk1 = df_rhf_grad.get_jk(self, mol, dm, hermi, decompose_j2c='ED', lindep=lrdf_obj.lr_thresh) - vj += vj1 - vk += vk1 + vj = vj1 - np.asarray(vj) + vk = vk1 - np.asarray(vk) if self.auxbasis_response: vj = lib.tag_array(vj, aux=vj1.aux) vk = lib.tag_array(vk, aux=vk1.aux) diff --git a/pyscf/lrdf/lrdf.py b/pyscf/lrdf/lrdf.py index 230f380855..44019fba53 100644 --- a/pyscf/lrdf/lrdf.py +++ b/pyscf/lrdf/lrdf.py @@ -7,6 +7,7 @@ import ctypes import tempfile import numpy as np +import scipy.special from pyscf import gto from pyscf import lib from pyscf.lib import logger @@ -14,8 +15,8 @@ from pyscf.gto import ft_ao from pyscf.dft.gen_grid import LEBEDEV_NGRID, libdft from pyscf.gto.moleintor import make_cintopt -from pyscf.pbc.df.incore import libpbc -from pyscf.scf._vhf import libcvhf, _fpointer +from pyscf.scf._vhf import libcvhf +from pyscf.scf import _vhf MIN_CUTOFF = 1e-44 AUXBASIS = { @@ -23,15 +24,6 @@ 'default': [[0, [1., 1]], [1, [1., 1]], [2, [1., 1]]] } -class _CVHFOpt(ctypes.Structure): - _fields_ = [('nbas', ctypes.c_int), - ('ngrids', ctypes.c_int), - ('log_cutoff', ctypes.c_double), - ('logq_cond', ctypes.c_void_p), - ('dm_cond', ctypes.c_void_p), - ('fprescreen', ctypes.c_void_p), - ('r_vkscreen', ctypes.c_void_p)] - class LRDensityFitting(df.DF): omega = 0.1 @@ -42,18 +34,15 @@ class LRDensityFitting(df.DF): lr_dfj = True def __init__(self, mol, auxbasis=None): - self._intor = 'int2e' - self._cintopt = None - self.q_cond = None self.lr_auxmol = None self.wcoulG = None self.Gv = None + self._vhfopt = None self._last_vs = (0, 0, 0) df.DF.__init__(self, mol, auxbasis) def reset(self, mol=None): - self.q_cond = None - self._cintopt = None + self._vhfopt = None return df.DF.reset(self, mol) def dump_flags(self, verbose=None): @@ -81,24 +70,12 @@ def build(self): self.dump_flags() mol = self.mol - nbas = mol.nbas - self.q_cond = np.empty((6,nbas,nbas), dtype=np.float32) - ao_loc = mol.ao_loc omega = self.omega assert omega > 0 - with mol.with_short_range_coulomb(omega): - self._cintopt = make_cintopt( - mol._atm, mol._bas, mol._env, self._intor) - - with mol.with_integral_screen(self.direct_scf_tol**2): - libpbc.CVHFsetnr_sr_direct_scf( - libpbc.int2e_sph, self._cintopt, - self.q_cond.ctypes.data_as(ctypes.c_void_p), - ao_loc.ctypes.data_as(ctypes.c_void_p), - mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm), - mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas), - mol._env.ctypes.data_as(ctypes.c_void_p)) + self._vhfopt = vhfopt = _VHFOpt(mol, 'int2e', omega=omega) + with mol.with_integral_screen(self.direct_scf_tol**2): + vhfopt.init_cvhf_direct(mol) cpu0 = log.timer('initializing q_cond', *cpu0) if self.lr_auxmol is None: @@ -165,7 +142,7 @@ def get_jk(self, dm, hermi=1, with_j=True, with_k=True, return vj, vk def _get_jk_sr(self, dm, hermi=1, with_j=True, with_k=True): - if self.q_cond is None: + if self._vhfopt is None: self.build() assert hermi == 1 @@ -174,61 +151,21 @@ def _get_jk_sr(self, dm, hermi=1, with_j=True, with_k=True): mol = self.mol n_dm, nao = dm.shape[:2] - dm_cond = _make_dm_cond(mol, dm, self.direct_scf_tol) - vhfopt = _CVHFOpt() - vhfopt.dm_cond = dm_cond.ctypes.data_as(ctypes.c_void_p) - vhfopt.logq_cond = self.q_cond.ctypes.data_as(ctypes.c_void_p) - vhfopt.log_cutoff = np.log(self.direct_scf_tol) - - intor = mol._add_suffix(self._intor) - cintor = getattr(libcvhf, intor) - fdot = getattr(libcvhf, 'CVHFdot_nr_sr_s8') - vj = vk = None - dmsptr = [] - vjkptr = [] - fjk = [] - - if with_j: - fvj = _fpointer('CVHFnrs8_ji_s2kl') - vj = np.empty((n_dm,nao,nao)) - for i in range(n_dm): - dmsptr.append(dm[i].ctypes.data_as(ctypes.c_void_p)) - vjkptr.append(vj[i].ctypes.data_as(ctypes.c_void_p)) - fjk.append(fvj) - - if with_k: - fvk = _fpointer('CVHFnrs8_li_s2kj') - vk = np.empty((n_dm,nao,nao)) - for i in range(n_dm): - dmsptr.append(dm[i].ctypes.data_as(ctypes.c_void_p)) - vjkptr.append(vk[i].ctypes.data_as(ctypes.c_void_p)) - fjk.append(fvk) - - shls_slice = (ctypes.c_int*8)(*([0, mol.nbas]*4)) - ao_loc = mol.ao_loc - n_ops = len(dmsptr) - comp = 1 + if with_j and with_k: + out = np.empty((2*n_dm, nao, nao)) + vj = out[:n_dm] + vk = out[n_dm:] + elif with_k: + vj = out = np.empty((n_dm, nao, nao)) + elif with_k: + vk = out = np.empty((n_dm, nao, nao)) + else: + return vj, vk with mol.with_short_range_coulomb(self.omega): - libcvhf.CVHFnr_sr_direct_drv( - cintor, fdot, (ctypes.c_void_p*n_ops)(*fjk), - (ctypes.c_void_p*n_ops)(*dmsptr), - (ctypes.c_void_p*n_ops)(*vjkptr), - ctypes.c_int(n_ops), ctypes.c_int(comp), - shls_slice, ao_loc.ctypes.data_as(ctypes.c_void_p), - self._cintopt, ctypes.byref(vhfopt), - mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm), - mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas), - mol._env.ctypes.data_as(ctypes.c_void_p)) - - if with_j: - for i in range(n_dm): - lib.hermi_triu(vj[i], 1, inplace=True) - if with_k: - if hermi != 0: - for i in range(n_dm): - lib.hermi_triu(vk[i], hermi, inplace=True) + _vhf.direct(dm, mol._atm, mol._bas, mol._env, self._vhfopt, hermi, + mol.cart, with_j, with_k, out, optimize_sr=True) logger.timer(mol, 'short range part vj and vk', *cpu0) return vj, vk @@ -294,13 +231,48 @@ def _get_jk_lr(self, dm, hermi=1, with_j=True, with_k=True): LRDF = LRDensityFitting -def _make_dm_cond(mol, dm, direct_scf_tol): - assert dm.ndim == 3 - ao_loc = mol.ao_loc - dm_cond = [lib.condense('NP_absmax', d, ao_loc, ao_loc) for d in dm] - dm_cond = np.max(dm_cond, axis=0) - dm_cond += MIN_CUTOFF # to remove divide-by-zero error - return np.asarray(dm_cond, order='C', dtype=np.float32) +class _VHFOpt(_vhf._VHFOpt): + def __init__(self, mol, intor=None, prescreen='CVHFnoscreen', + qcondname=None, dmcondname=None, omega=None): + assert omega is not None + with mol.with_short_range_coulomb(omega): + _vhf._VHFOpt.__init__(self, mol, intor, prescreen, qcondname, dmcondname) + self.omega = omega + self._this.direct_scf_cutoff = np.log(1e-14) + + @property + def direct_scf_tol(self): + return np.exp(self._this.direct_scf_cutoff) + @direct_scf_tol.setter + def direct_scf_tol(self, v): + self._this.direct_scf_cutoff = np.log(v) + + def init_cvhf_direct(self, mol, intor=None, qcondname=None): + nbas = mol.nbas + q_cond = np.empty((6,nbas,nbas), dtype=np.float32) + ao_loc = mol.ao_loc + cintopt = self._cintopt + with mol.with_short_range_coulomb(self.omega): + libcvhf.CVHFsetnr_sr_direct_scf( + libcvhf.int2e_sph, cintopt, + q_cond.ctypes.data_as(ctypes.c_void_p), + ao_loc.ctypes.data_as(ctypes.c_void_p), + mol._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm), + mol._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas), + mol._env.ctypes.data_as(ctypes.c_void_p)) + + self._q_cond = q_cond + logq_cond = q_cond.ctypes.data_as(ctypes.c_void_p) + self._this.q_cond = logq_cond + + def set_dm(self, dm, atm=None, bas=None, env=None): + assert dm[0].ndim == 2 + ao_loc = self.mol.ao_loc_nr() + dm_cond = [lib.condense('NP_absmax', d, ao_loc, ao_loc) for d in dm] + dm_cond = np.max(dm_cond, axis=0) + dm_cond += MIN_CUTOFF # to remove divide-by-zero error + self._dm_cond = np.asarray(dm_cond, order='C', dtype=np.float32) + self._this.dm_cond = self._dm_cond.ctypes.data_as(ctypes.c_void_p) def _quadrature_roots(n, omega): rs, ws = scipy.special.roots_hermite(n*2) diff --git a/pyscf/lrdf/test/test_df_grad.py b/pyscf/lrdf/test/test_df_grad.py index ca9e0b55f4..7987fe92b2 100644 --- a/pyscf/lrdf/test/test_df_grad.py +++ b/pyscf/lrdf/test/test_df_grad.py @@ -50,5 +50,5 @@ def test_rhf_grad(self): self.assertAlmostEqual(abs(g1-ref).max(), 0, 5) if __name__ == "__main__": - print("Full Tests for df.grad") + print("Full Tests for lrdf.grad") unittest.main() diff --git a/pyscf/lrdf/test/test_lrdf.py b/pyscf/lrdf/test/test_lrdf.py index f2b68306e3..5f70520db5 100644 --- a/pyscf/lrdf/test/test_lrdf.py +++ b/pyscf/lrdf/test/test_lrdf.py @@ -3,7 +3,7 @@ from pyscf import lib from pyscf import gto from pyscf.scf import hf -from pyscf.df import lrdf +from pyscf.lrdf import lrdf class KnownValues(unittest.TestCase): diff --git a/pyscf/scf/_vhf.py b/pyscf/scf/_vhf.py index 8476c92604..0a1093395c 100644 --- a/pyscf/scf/_vhf.py +++ b/pyscf/scf/_vhf.py @@ -35,7 +35,6 @@ def __init__(self, mol, intor=None, names of C functions defined in libcvhf module ''' self._this = ctypes.POINTER(_CVHFOpt)() - #print self._this.contents, expect ValueError: NULL pointer access if intor is None: self._intor = intor @@ -148,6 +147,111 @@ def get_dm_cond(self, shape=None): return numpy.ctypeslib.as_array(data, shape=shape) dm_cond = property(get_dm_cond) +class _VHFOpt: + def __init__(self, mol, intor=None, + prescreen='CVHFnoscreen', qcondname=None, dmcondname=None): + '''New version of VHFOpt (under development). + + If function "qcondname" is presented, the qcond (sqrt(integrals)) + and will be initialized in __init__. + + prescreen, qcondname, dmcondname can be either function pointers or + names of C functions defined in libcvhf module + ''' + self.mol = mol + self._q_cond = None + self._dm_cond = None + self._this = cvhfopt = _CVHFOpt() + cvhfopt.nbas = mol.nbas + cvhfopt.direct_scf_cutoff = 1e-14 + cvhfopt.fprescreen = _fpointer(prescreen) + cvhfopt.r_vkscreen = _fpointer('CVHFr_vknoscreen') + + if intor is None: + self._intor = intor + self._cintopt = lib.c_null_ptr() + else: + self._intor = mol._add_suffix(intor) + self._cintopt = make_cintopt(mol._atm, mol._bas, mol._env, intor) + + self._dmcondname = dmcondname + self._qcondname = qcondname + if qcondname is not None and intor is not None: + self.init_cvhf_direct(mol, intor, qcondname) + + def init_cvhf_direct(self, mol, intor, qcondname): + '''qcondname can be the function pointer or the name of a C function + defined in libcvhf module + ''' + intor = mol._add_suffix(intor) + assert intor == self._intor + cintopt = self._cintopt + ao_loc = mol.ao_loc_nr() + if isinstance(qcondname, ctypes._CFuncPtr): + fqcond = qcondname + else: + fqcond = getattr(libcvhf, qcondname) + nbas = mol.nbas + q_cond = self._q_cond = numpy.empty((nbas, nbas)) + fqcond(getattr(libcvhf, intor), cintopt, q_cond.ctypes, + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + self._this.q_cond = q_cond.ctypes.data_as(ctypes.c_void_p) + self._qcondname = qcondname + + @property + def direct_scf_tol(self): + return self._this.direct_scf_cutoff + @direct_scf_tol.setter + def direct_scf_tol(self, v): + self._this.direct_scf_cutoff = v + + @property + def prescreen(self): + return self._this.fprescreen + @prescreen.setter + def prescreen(self, v): + if isinstance(v, str): + v = _fpointer(v) + self._this.fprescreen = v + + def set_dm(self, dm, atm, bas, env): + if self._dmcondname is None: + return + + mol = self.mol + if isinstance(dm, numpy.ndarray) and dm.ndim == 2: + n_dm = 1 + else: + n_dm = len(dm) + dm = numpy.asarray(dm, order='C') + ao_loc = mol.ao_loc_nr() + if isinstance(self._dmcondname, ctypes._CFuncPtr): + fdmcond = self._dmcondname + else: + fdmcond = getattr(libcvhf, self._dmcondname) + nbas = mol.nbas + dm_cond = numpy.empty((nbas, nbas)) + fdmcond(dm_cond.ctypes, dm.ctypes, ctypes.c_int(n_dm), + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + self._dm_cond = dm_cond + self._this.dm_cond = dm_cond.ctypes.data_as(ctypes.c_void_p) + + def get_q_cond(self, shape=None): + '''Return an array associated to q_cond. Contents of q_cond can be + modified through this array + ''' + return self._q_cond + q_cond = property(get_q_cond) + + def get_dm_cond(self, shape=None): + '''Return an array associated to dm_cond. Contents of dm_cond can be + modified through this array + ''' + return self._dm_cond + dm_cond = property(get_dm_cond) + class SGXOpt(VHFOpt): def __init__(self, mol, intor=None, @@ -301,13 +405,7 @@ def incore(eri, dms, hermi=0, with_j=True, with_k=True): # use int2e_sph as cintor, CVHFnrs8_ij_s2kl, CVHFnrs8_jk_s2il as fjk to call # direct_mapdm def direct(dms, atm, bas, env, vhfopt=None, hermi=0, cart=False, - with_j=True, with_k=True): - c_atm = numpy.asarray(atm, dtype=numpy.int32, order='C') - c_bas = numpy.asarray(bas, dtype=numpy.int32, order='C') - c_env = numpy.asarray(env, dtype=numpy.double, order='C') - natm = ctypes.c_int(c_atm.shape[0]) - nbas = ctypes.c_int(c_bas.shape[0]) - + with_j=True, with_k=True, out=None, optimize_sr=False): dms = numpy.asarray(dms, order='C', dtype=numpy.double) dms_shape = dms.shape nao = dms_shape[-1] @@ -315,57 +413,45 @@ def direct(dms, atm, bas, env, vhfopt=None, hermi=0, cart=False, n_dm = dms.shape[0] if vhfopt is None: + cvhfopt = None + cintopt = None if cart: intor = 'int2e_cart' else: intor = 'int2e_sph' - cintopt = make_cintopt(c_atm, c_bas, c_env, intor) - cvhfopt = lib.c_null_ptr() else: vhfopt.set_dm(dms, atm, bas, env) cvhfopt = vhfopt._this cintopt = vhfopt._cintopt intor = vhfopt._intor - cintor = _fpointer(intor) - - fdrv = getattr(libcvhf, 'CVHFnr_direct_drv') - fdot = _fpointer('CVHFdot_nrs8') vj = vk = None - dmsptr = [] - vjkptr = [] - fjk = [] - + jkscripts = [] + n_jk = 0 if with_j: - fvj = _fpointer('CVHFnrs8_ji_s2kl') - vj = numpy.empty((n_dm,nao,nao)) - for i, dm in enumerate(dms): - dmsptr.append(dm.ctypes.data_as(ctypes.c_void_p)) - vjkptr.append(vj[i].ctypes.data_as(ctypes.c_void_p)) - fjk.append(fvj) - + jkscripts.extend(['ji->s2kl']*n_dm) + n_jk += 1 if with_k: if hermi == 1: - fvk = _fpointer('CVHFnrs8_li_s2kj') + jkscripts.extend(['li->s2kj']*n_dm) else: - fvk = _fpointer('CVHFnrs8_li_s1kj') - vk = numpy.empty((n_dm,nao,nao)) - for i, dm in enumerate(dms): - dmsptr.append(dm.ctypes.data_as(ctypes.c_void_p)) - vjkptr.append(vk[i].ctypes.data_as(ctypes.c_void_p)) - fjk.append(fvk) - - shls_slice = (ctypes.c_int*8)(*([0, c_bas.shape[0]]*4)) - ao_loc = make_loc(bas, intor) - n_ops = len(dmsptr) - comp = 1 - fdrv(cintor, fdot, (ctypes.c_void_p*n_ops)(*fjk), - (ctypes.c_void_p*n_ops)(*dmsptr), (ctypes.c_void_p*n_ops)(*vjkptr), - ctypes.c_int(n_ops), ctypes.c_int(comp), - shls_slice, ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, cvhfopt, - c_atm.ctypes.data_as(ctypes.c_void_p), natm, - c_bas.ctypes.data_as(ctypes.c_void_p), nbas, - c_env.ctypes.data_as(ctypes.c_void_p)) + jkscripts.extend(['li->s1kj']*n_dm) + n_jk += 1 + if n_jk == 0: + return vj, vk + + dms = list(dms) * n_jk # make n_jk copies of dms + if out is None: + out = numpy.empty((n_jk*n_dm, nao, nao)) + nr_direct_drv(intor, 's8', jkscripts, dms, 1, atm, bas, env, + cvhfopt, cintopt, out=out, optimize_sr=optimize_sr) + if with_j and with_k: + vj = out[:n_dm] + vk = out[n_dm:] + elif with_j: + vj = out + else: + vk = out if with_j: # vj must be symmetric @@ -383,91 +469,45 @@ def direct(dms, atm, bas, env, vhfopt=None, hermi=0, cart=False, # jkdescript: 'ij->s1kl', 'kl->s2ij', ... def direct_mapdm(intor, aosym, jkdescript, dms, ncomp, atm, bas, env, vhfopt=None, cintopt=None, - shls_slice=None, shls_excludes=None): - assert (aosym in ('s8', 's4', 's2ij', 's2kl', 's1', - 'aa4', 'a4ij', 'a4kl', 'a2ij', 'a2kl')) - intor = ascint3(intor) - c_atm = numpy.asarray(atm, dtype=numpy.int32, order='C') - c_bas = numpy.asarray(bas, dtype=numpy.int32, order='C') - c_env = numpy.asarray(env, dtype=numpy.double, order='C') - natm = ctypes.c_int(c_atm.shape[0]) - nbas = ctypes.c_int(c_bas.shape[0]) - + shls_slice=None, shls_excludes=None, out=None, + optimize_sr=False): if isinstance(dms, numpy.ndarray) and dms.ndim == 2: dms = dms[numpy.newaxis,:,:] + single_dm = True + else: + single_dm = False n_dm = len(dms) dms = [numpy.asarray(dm, order='C', dtype=numpy.double) for dm in dms] if isinstance(jkdescript, str): - jkdescripts = (jkdescript,) + jkscripts = (jkdescript,) else: - jkdescripts = jkdescript - njk = len(jkdescripts) + jkscripts = jkdescript + n_jk = len(jkscripts) + + # make n_jk copies of dms + dms = dms * n_jk + # make n_dm copies for each jk script + jkscripts = numpy.repeat(jkscripts, n_dm) + intor = ascint3(intor) if vhfopt is None: - cintor = _fpointer(intor) cvhfopt = lib.c_null_ptr() + cintopt = None else: vhfopt.set_dm(dms, atm, bas, env) cvhfopt = vhfopt._this cintopt = vhfopt._cintopt - cintor = getattr(libcvhf, vhfopt._intor) - if cintopt is None: - cintopt = make_cintopt(c_atm, c_bas, c_env, intor) - - if shls_slice is None: - shls_slice = [0, c_bas.shape[0]] * 4 - if shls_excludes is not None: - shls_excludes = _check_shls_excludes(shls_slice, shls_excludes) - - ao_loc = make_loc(bas, intor) - - vjk = [] - descr_sym = [x.split('->') for x in jkdescripts] - fjk = (ctypes.c_void_p*(njk*n_dm))() - dmsptr = (ctypes.c_void_p*(njk*n_dm))() - vjkptr = (ctypes.c_void_p*(njk*n_dm))() - - for i, (dmsym, vsym) in enumerate(descr_sym): - if dmsym in ('ij', 'kl', 'il', 'kj'): - sys.stderr.write('not support DM description %s, transpose to %s\n' % - (dmsym, dmsym[::-1])) - dmsym = dmsym[::-1] - f1 = _fpointer('CVHFnr%s_%s_%s'%(aosym, dmsym, vsym)) - - vshape = (n_dm, ncomp) + get_dims(vsym[-2:], shls_slice, ao_loc) - vjk.append(numpy.empty(vshape)) - for j in range(n_dm): - if dms[j].shape != get_dims(dmsym, shls_slice, ao_loc): - raise RuntimeError('dm[%d] shape %s is inconsistent with the ' - 'shls_slice shape %s' % - (j, dms[j].shape, get_dims(dmsym, shls_slice, ao_loc))) - dmsptr[i*n_dm+j] = dms[j].ctypes.data_as(ctypes.c_void_p) - vjkptr[i*n_dm+j] = vjk[i][j].ctypes.data_as(ctypes.c_void_p) - fjk[i*n_dm+j] = f1 - - if shls_excludes is None: - fdrv = getattr(libcvhf, 'CVHFnr_direct_drv') - shls_slice = (ctypes.c_int*8)(*shls_slice) - else: - fdrv = getattr(libcvhf, 'CVHFnr_direct_ex_drv') - shls_slice = (ctypes.c_int*16)(*shls_slice, *shls_excludes) - dotsym = _INTSYMAP[aosym] - fdot = _fpointer('CVHFdot_nr'+dotsym) - fdrv(cintor, fdot, fjk, dmsptr, vjkptr, - ctypes.c_int(njk*n_dm), ctypes.c_int(ncomp), - shls_slice, ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, cvhfopt, - c_atm.ctypes.data_as(ctypes.c_void_p), natm, - c_bas.ctypes.data_as(ctypes.c_void_p), nbas, - c_env.ctypes.data_as(ctypes.c_void_p)) + vjk = nr_direct_drv(intor, aosym, jkscripts, dms, ncomp, atm, bas, env, + cvhfopt, cintopt, shls_slice, shls_excludes, out, + optimize_sr=optimize_sr) + if ncomp == 1: + vjk = [v[0] for v in vjk] - if n_dm * ncomp == 1: - vjk = [v[0,0] for v in vjk] - elif n_dm == 1: - vjk = [v[0,:] for v in vjk] - elif ncomp == 1: - vjk = [v[:,0] for v in vjk] - if isinstance(jkdescript, str): + # vjk.reshape(n_jk,n_dm,...).transpose(1,0,...) + if n_dm > 1 and n_jk > 1: + vjk = [vjk[i::n_dm] for i in range(n_dm)] + elif n_jk == 1 and single_dm: vjk = vjk[0] return vjk @@ -475,10 +515,36 @@ def direct_mapdm(intor, aosym, jkdescript, # jkdescript: 'ij->s1kl', 'kl->s2ij', ... def direct_bindm(intor, aosym, jkdescript, dms, ncomp, atm, bas, env, vhfopt=None, cintopt=None, - shls_slice=None, shls_excludes=None): - assert (aosym in ('s8', 's4', 's2ij', 's2kl', 's1', - 'aa4', 'a4ij', 'a4kl', 'a2ij', 'a2kl')) + shls_slice=None, shls_excludes=None, out=None, + optimize_sr=False): intor = ascint3(intor) + if vhfopt is None: + cvhfopt = lib.c_null_ptr() + cintopt = None + else: + vhfopt.set_dm(dms, atm, bas, env) + cvhfopt = vhfopt._this + cintopt = vhfopt._cintopt + + vjk = nr_direct_drv(intor, aosym, jkdescript, dms, ncomp, atm, bas, env, + cvhfopt, cintopt, shls_slice, shls_excludes, out, + optimize_sr=optimize_sr) + if ncomp == 1: + if isinstance(jkdescript, str): + vjk = vjk[0] + else: + vjk = [v[0] for v in vjk] + return vjk + +def nr_direct_drv(intor, aosym, jkscript, + dms, ncomp, atm, bas, env, cvhfopt=None, cintopt=None, + shls_slice=None, shls_excludes=None, out=None, + optimize_sr=True): + if optimize_sr: + assert aosym in ('s8', 's4', 's2ij', 's2kl', 's1') + else: + assert aosym in ('s8', 's4', 's2ij', 's2kl', 's1', + 'aa4', 'a4ij', 'a4kl', 'a2ij', 'a2kl') c_atm = numpy.asarray(atm, dtype=numpy.int32, order='C') c_bas = numpy.asarray(bas, dtype=numpy.int32, order='C') c_env = numpy.asarray(env, dtype=numpy.double, order='C') @@ -487,23 +553,22 @@ def direct_bindm(intor, aosym, jkdescript, if isinstance(dms, numpy.ndarray) and dms.ndim == 2: dms = dms[numpy.newaxis,:,:] + assert dms[0].ndim == 2 n_dm = len(dms) dms = [numpy.asarray(dm, order='C', dtype=numpy.double) for dm in dms] - if isinstance(jkdescript, str): - jkdescripts = (jkdescript,) + + if isinstance(jkscript, str): + jkscripts = (jkscript,) else: - jkdescripts = jkdescript - njk = len(jkdescripts) - assert (njk == n_dm) + jkscripts = jkscript + assert len(jkscripts) == n_dm - if vhfopt is None: - cintor = _fpointer(intor) + if cvhfopt is None: cvhfopt = lib.c_null_ptr() - else: - vhfopt.set_dm(dms, atm, bas, env) - cvhfopt = vhfopt._this - cintopt = vhfopt._cintopt - cintor = getattr(libcvhf, vhfopt._intor) + elif isinstance(cvhfopt, ctypes.Structure): + # To make cvhfopt _VHFOpt comparable + assert cvhfopt.dm_cond and cvhfopt.q_cond + cvhfopt = ctypes.byref(cvhfopt) if cintopt is None: cintopt = make_cintopt(c_atm, c_bas, c_env, intor) @@ -515,7 +580,7 @@ def direct_bindm(intor, aosym, jkdescript, ao_loc = make_loc(bas, intor) vjk = [] - descr_sym = [x.split('->') for x in jkdescripts] + descr_sym = [x.split('->') for x in jkscripts] fjk = (ctypes.c_void_p*(n_dm))() dmsptr = (ctypes.c_void_p*(n_dm))() vjkptr = (ctypes.c_void_p*(n_dm))() @@ -527,30 +592,42 @@ def direct_bindm(intor, aosym, jkdescript, 'shls_slice shape %s' % (i, dms[i].shape, get_dims(dmsym, shls_slice, ao_loc))) vshape = (ncomp,) + get_dims(vsym[-2:], shls_slice, ao_loc) - vjk.append(numpy.empty(vshape)) + if out is None: + buf = numpy.empty(vshape) + else: + buf = numpy.ndarray(vshape, dtype=numpy.double, buffer=out[i]) + vjk.append(buf) dmsptr[i] = dms[i].ctypes.data_as(ctypes.c_void_p) vjkptr[i] = vjk[i].ctypes.data_as(ctypes.c_void_p) fjk[i] = f1 - if shls_excludes is None: - fdrv = getattr(libcvhf, 'CVHFnr_direct_drv') + omega = env[gto.PTR_RANGE_OMEGA] + if omega < 0 and optimize_sr: + assert shls_excludes is None + drv = 'CVHFnr_sr_direct_drv' + shls_slice = (ctypes.c_int*8)(*shls_slice) + elif shls_excludes is None: + drv = 'CVHFnr_direct_drv' shls_slice = (ctypes.c_int*8)(*shls_slice) else: - fdrv = getattr(libcvhf, 'CVHFnr_direct_ex_drv') + drv = 'CVHFnr_direct_ex_drv' shls_slice = (ctypes.c_int*16)(*shls_slice, *shls_excludes) + fdrv = getattr(libcvhf, drv) dotsym = _INTSYMAP[aosym] - fdot = _fpointer('CVHFdot_nr'+dotsym) + if omega < 0 and optimize_sr: + fdot = getattr(libcvhf, 'CVHFdot_sr_nr'+dotsym) + else: + fdot = getattr(libcvhf, 'CVHFdot_nr'+dotsym) + cintor = getattr(libcvhf, intor) fdrv(cintor, fdot, fjk, dmsptr, vjkptr, - ctypes.c_int(n_dm), ctypes.c_int(ncomp), - shls_slice, ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, cvhfopt, + ctypes.c_int(n_dm), ctypes.c_int(ncomp), shls_slice, + ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, cvhfopt, c_atm.ctypes.data_as(ctypes.c_void_p), natm, c_bas.ctypes.data_as(ctypes.c_void_p), nbas, c_env.ctypes.data_as(ctypes.c_void_p)) - if ncomp == 1: - vjk = [v[0] for v in vjk] - if isinstance(jkdescript, str): + if isinstance(jkscript, str): vjk = vjk[0] return vjk diff --git a/pyscf/scf/hf.py b/pyscf/scf/hf.py index 52b301475c..1f157ed402 100644 --- a/pyscf/scf/hf.py +++ b/pyscf/scf/hf.py @@ -1753,9 +1753,8 @@ def init_direct_scf(self, mol=None): # Higher accuracy is required for Schwartz inequality prescreening. cpu0 = (logger.process_clock(), logger.perf_counter()) with mol.with_integral_screen(self.direct_scf_tol**2): - opt = _vhf.VHFOpt(mol, 'int2e', 'CVHFnrs8_prescreen', - 'CVHFsetnr_direct_scf', - 'CVHFsetnr_direct_scf_dm') + opt = _vhf._VHFOpt(mol, 'int2e', 'CVHFnrs8_prescreen', + 'CVHFnr_int2e_q_cond', 'CVHFnr_dm_cond') opt.direct_scf_tol = self.direct_scf_tol logger.timer(self, 'init_direct_scf', *cpu0) return opt