From 100288e8007c8d85c0f5b4aadbe6dadc39c4b160 Mon Sep 17 00:00:00 2001 From: JasonLeeJsL Date: Mon, 29 Jan 2024 18:10:53 +0800 Subject: [PATCH] Test with ground truth --- tests/native/xla/eri_kernel_test.py | 59 +++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/native/xla/eri_kernel_test.py b/tests/native/xla/eri_kernel_test.py index 4cd8a65..babc7b5 100644 --- a/tests/native/xla/eri_kernel_test.py +++ b/tests/native/xla/eri_kernel_test.py @@ -93,29 +93,74 @@ def num_unique_ijkl(n): # pyscf_mol = get_pyscf_mol("C80-D5d-1", "sto-3g") # pyscf_mol = get_pyscf_mol("C90-D5h-1", "sto-3g") # pyscf_mol = get_pyscf_mol("C100-D5d-1", "sto-3g") - pyscf_mol = get_pyscf_mol("C180-0", "sto-3g") + # pyscf_mol = get_pyscf_mol("C180-0", "sto-3g") # pyscf_mol = get_pyscf_mol("C240-0", "sto-3g") # pyscf_mol = get_pyscf_mol("C320-0", "sto-3g") # pyscf_mol = get_pyscf_mol("C500-0", "sto-3g") # pyscf_mol = get_pyscf_mol("O2", "6-31G") - # pyscf_mol = get_pyscf_mol("O2", "sto-3g") + pyscf_mol = get_pyscf_mol("O2", "sto-3g") mol = Mol.from_pyscf_mol(pyscf_mol) cgto = CGTO.from_mol(mol) self.s = angular_stats.angular_static_args(*[cgto.pgto.angular] * 4) self.cgto = cgto self.ab_idx_counts = symmetry.get_2c_sym_idx(cgto.n_pgtos) - n_2c_idx = len(self.ab_idx_counts) - + + key = jax.random.PRNGKey(42) self.Mo_coeff = jax.random.normal(key,(2, self.cgto.n_cgtos, self.cgto.n_cgtos)) + + def test_abcd(self) -> None: - out_abcd = compute_hartree_test(self.cgto, self.s, self.Mo_coeff) + out_abcd = compute_hartree_test(self.cgto, self.s, self.Mo_coeff, eps=1e-10) print(out_abcd) + # # Compute Ground Truth + # ab_idx, counts_ab = self.ab_idx_counts[:, :2], self.ab_idx_counts[:, 2] + # self.abab_idx_count = jnp.hstack([ab_idx, ab_idx, + # counts_ab[:, None]*counts_ab[:, None]]).astype(int) + + # n_2c_idx = len(self.ab_idx_counts) + # num_4c_idx = symmetry.num_unique_ij(n_2c_idx) + # self.num_4c_idx = num_4c_idx + # batch_size: int = 2**23 + # i = 0 + # start = batch_size * i + # end = num_4c_idx + # slice_size = num_4c_idx - start + # start_idx = symmetry.get_triu_ij_from_idx(n_2c_idx, start) + # end_idx = symmetry.get_triu_ij_from_idx(n_2c_idx, end) + # self.abcd_idx_counts = symmetry.get_4c_sym_idx_range( + # self.ab_idx_counts, n_2c_idx, start_idx, end_idx, slice_size + # ) + + # self.n_segs = symmetry.num_unique_ijkl(self.cgto.n_cgtos) + # self.cgto_seg_id = symmetry.get_cgto_segment_id_sym( + # self.abcd_idx_counts[:, :-1], self.cgto.cgto_splits, four_center=True + # ) + + # nmo = self.cgto.n_cgtos # assuming same number of MOs and AOs + # self.mo_abcd_idx_counts = symmetry.get_4c_sym_idx(nmo) + # self.rdm1 = get_rdm1(self.Mo_coeff).sum(0) + # self.rdm1_ab = self.rdm1[self.mo_abcd_idx_counts[:, 0], self.mo_abcd_idx_counts[:, 1]] + # self.rdm1_cd = self.rdm1[self.mo_abcd_idx_counts[:, 2], self.mo_abcd_idx_counts[:, 3]] -def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin): + # cgto_4c_fn_gt = tensorization.tensorize_4c_cgto(electron_repulsion_integral, self.s) + # e_raw_gt = cgto_4c_fn_gt(self.cgto, self.abcd_idx_counts, self.cgto_seg_id, self.n_segs) + # hartree_e_raw_gt = jnp.sum(e_raw_gt * self.rdm1_ab * self.rdm1_cd) + # print(hartree_e_raw_gt) + + + +def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin, eps = 1e-10, thread_load = 2**10): + """Compute contracted ERI + + Args: + cgto: cgto of molecule + static_args: statis arguments for orbitals + Mo_coeff_spin: molecule coefficients with spin + """ l_xyz = jnp.sum(cgto.pgto.angular, 1) orig_idx = jnp.argsort(l_xyz) @@ -178,8 +223,6 @@ def compute_hartree_test(cgto: CGTO, static_args: AngularStats, Mo_coeff_spin): output = 0 for i in range(6): for j in range(i, 6): - eps = 1e-10 - thread_load = 2**10 sorted_ab_idx = sorted_idx[i] sorted_cd_idx = sorted_idx[j] if len(sorted_ab_idx) == 0 or len(sorted_cd_idx) == 0: