Skip to content

Commit

Permalink
Test with ground truth
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonLeeJSL committed Jan 29, 2024
1 parent a470091 commit 100288e
Showing 1 changed file with 51 additions and 8 deletions.
59 changes: 51 additions & 8 deletions tests/native/xla/eri_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,74 @@ def num_unique_ijkl(n):
# pyscf_mol = get_pyscf_mol("C80-D5d-1", "sto-3g")
# pyscf_mol = get_pyscf_mol("C90-D5h-1", "sto-3g")
# pyscf_mol = get_pyscf_mol("C100-D5d-1", "sto-3g")
pyscf_mol = get_pyscf_mol("C180-0", "sto-3g")
# pyscf_mol = get_pyscf_mol("C180-0", "sto-3g")
# pyscf_mol = get_pyscf_mol("C240-0", "sto-3g")
# pyscf_mol = get_pyscf_mol("C320-0", "sto-3g")
# pyscf_mol = get_pyscf_mol("C500-0", "sto-3g")
# pyscf_mol = get_pyscf_mol("O2", "6-31G")
# pyscf_mol = get_pyscf_mol("O2", "sto-3g")
pyscf_mol = get_pyscf_mol("O2", "sto-3g")
mol = Mol.from_pyscf_mol(pyscf_mol)
cgto = CGTO.from_mol(mol)
self.s = angular_stats.angular_static_args(*[cgto.pgto.angular] * 4)
self.cgto = cgto
self.ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos)
n_2c_idx = len(self.ab_idx_counts)


key = jax.random.PRNGKey(42)
self.Mo_coeff = jax.random.normal(key,(2, self.cgto.n_cgtos, self.cgto.n_cgtos))



def test_abcd(self) -> None:
out_abcd = compute_hartree_test(self.cgto, self.s, self.Mo_coeff)
out_abcd = compute_hartree_test(self.cgto, self.s, self.Mo_coeff, eps=1e-10)
print(out_abcd)


# # Compute Ground Truth
# ab_idx, counts_ab = self.ab_idx_counts[:, :2], self.ab_idx_counts[:, 2]
# self.abab_idx_count = jnp.hstack([ab_idx, ab_idx,
# counts_ab[:, None]*counts_ab[:, None]]).astype(int)

# n_2c_idx = len(self.ab_idx_counts)
# num_4c_idx = symmetry.num_unique_ij(n_2c_idx)
# self.num_4c_idx = num_4c_idx
# batch_size: int = 2**23
# i = 0
# start = batch_size * i
# end = num_4c_idx
# slice_size = num_4c_idx - start
# start_idx = symmetry.get_triu_ij_from_idx(n_2c_idx, start)
# end_idx = symmetry.get_triu_ij_from_idx(n_2c_idx, end)
# self.abcd_idx_counts = symmetry.get_4c_sym_idx_range(
# self.ab_idx_counts, n_2c_idx, start_idx, end_idx, slice_size
# )

# self.n_segs = symmetry.num_unique_ijkl(self.cgto.n_cgtos)
# self.cgto_seg_id = symmetry.get_cgto_segment_id_sym(
# self.abcd_idx_counts[:, :-1], self.cgto.cgto_splits, four_center=True
# )

# nmo = self.cgto.n_cgtos # assuming same number of MOs and AOs
# self.mo_abcd_idx_counts = symmetry.get_4c_sym_idx(nmo)
# self.rdm1 = get_rdm1(self.Mo_coeff).sum(0)
# self.rdm1_ab = self.rdm1[self.mo_abcd_idx_counts[:, 0], self.mo_abcd_idx_counts[:, 1]]
# self.rdm1_cd = self.rdm1[self.mo_abcd_idx_counts[:, 2], self.mo_abcd_idx_counts[:, 3]]

def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin):
# cgto_4c_fn_gt = tensorization.tensorize_4c_cgto(electron_repulsion_integral, self.s)
# e_raw_gt = cgto_4c_fn_gt(self.cgto, self.abcd_idx_counts, self.cgto_seg_id, self.n_segs)
# hartree_e_raw_gt = jnp.sum(e_raw_gt * self.rdm1_ab * self.rdm1_cd)
# print(hartree_e_raw_gt)



def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin, eps = 1e-10, thread_load = 2**10):
"""Compute contracted ERI
Args:
cgto: cgto of molecule
static_args: statis arguments for orbitals
Mo_coeff_spin: molecule coefficients with spin
"""
l_xyz = jnp.sum(cgto.pgto.angular, 1)
orig_idx = jnp.argsort(l_xyz)

Expand Down Expand Up @@ -178,8 +223,6 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin):
output = 0
for i in range(6):
for j in range(i, 6):
eps = 1e-10
thread_load = 2**10
sorted_ab_idx = sorted_idx[i]
sorted_cd_idx = sorted_idx[j]
if len(sorted_ab_idx) == 0 or len(sorted_cd_idx) == 0:
Expand Down

0 comments on commit 100288e

Please sign in to comment.