diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index 6838f84b0..b5f61d71a 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9] + python-version: [3.10] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py index aad6bfefd..0c221bea1 100644 --- a/dmff/admp/qeq.py +++ b/dmff/admp/qeq.py @@ -15,6 +15,12 @@ try: import jaxopt + try: + from jaxopt import Broyden + JAXOPT_OLD = False + except ImportError: + JAXOPT_OLD = True + print("jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern.") except ImportError: print("jaxopt not found, QEQ cannot be used.") import jax @@ -283,9 +289,12 @@ def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None): b_value = jnp.concatenate((aux["q"], aux["lagmt"])) else: b_value = jnp.concatenate([self.init_q, self.init_lagmt]) - rf = jaxopt.ScipyRootFinding( - optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10 - ) + if JAXOPT_OLD: + rf = jaxopt.ScipyRootFinding( + optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10 + ) + else: + rf = jaxopt.Broyden(fun=E_grads, tol=1e-10) b_0, _ = rf.run( b_value, chi, diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py index be4b9d99d..880292553 100644 --- a/tests/test_admp/test_compute.py +++ b/tests/test_admp/test_compute.py @@ -28,13 +28,14 @@ def test_init(self): H2 = Hamiltonian('tests/data/admp_nonpol.xml') pdb = app.PDBFile('tests/data/water_dimer.pdb') potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + potential_aux = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5, has_aux=True) potential1 = H1.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) potential2 = H2.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) - yield potential, potential1, potential2, H.paramset, H1.paramset, H2.paramset + yield potential, potential_aux, potential1, potential2, H.paramset, H1.paramset, H2.paramset def test_ADMPPmeForce(self, pot_prm): - potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -55,7 +56,7 @@ def test_ADMPPmeForce(self, pot_prm): def test_ADMPPmeForce_jit(self, pot_prm): - potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -73,10 +74,33 @@ def test_ADMPPmeForce_jit(self, pot_prm): energy, grad = j_pot_pme(positions, box, pairs, paramset.parameters) print('hahahah', energy) np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1) + + def test_ADMPPmeForce_aux(self, pot_prm): + potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + rc = 0.4 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + positions = jnp.array(positions) + a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer) + box = jnp.array([a, b, c]) + covalent_map = potential.meta["cov_map"] + # neighbor list + nblist = NeighborList(box, rc, covalent_map) + nblist.allocate(positions) + pairs = nblist.pairs + + aux = { + "U_ind": jnp.zeros((len(positions),3)), + } + pot = potential_aux.getPotentialFunc(names=["ADMPPmeForce"]) + j_pot_pme = jit(value_and_grad(pot, has_aux=True)) + (energy, grad), aux = j_pot_pme(positions, box, pairs, paramset.parameters, aux=aux) + print('hahahah', energy) + np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1) def test_ADMPPmeForce_mono(self, pot_prm): - potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) @@ -97,7 +121,7 @@ def test_ADMPPmeForce_mono(self, pot_prm): def test_ADMPPmeForce_nonpol(self, pot_prm): - potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm + potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm rc = 0.4 pdb = app.PDBFile('tests/data/water_dimer.pdb') positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) diff --git a/tests/test_admp/test_qeq.py b/tests/test_admp/test_qeq.py index 522861b13..22da4c6f3 100644 --- a/tests/test_admp/test_qeq.py +++ b/tests/test_admp/test_qeq.py @@ -5,6 +5,7 @@ from dmff.api.xmlio import XMLIO from dmff import NeighborList import jax.numpy as jnp +import jax import numpy as np @@ -38,7 +39,7 @@ def test_qeq_energy(): np.testing.assert_almost_equal(energy, -37.84692763, decimal=3) -def test_qeq_energy2(): +def test_qeq_energy_2res(): rc = 0.6 xml = XMLIO() xml.loadXML("tests/data/qeq2.xml") @@ -63,7 +64,6 @@ def test_qeq_energy2(): nblist = NeighborList(box, rc, dmfftop.buildCovMat()) pairs = nblist.allocate(pos) - print(pairs) const_list = [] const_list.append([]) @@ -86,4 +86,64 @@ def test_qeq_energy2(): } energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux) print(aux) - np.testing.assert_almost_equal(energy, 4817.295171, decimal=3) \ No newline at end of file + np.testing.assert_almost_equal(energy, 4817.295171, decimal=2) + + grad = jax.grad(efunc, argnums=0, has_aux=True) + gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux) + print(gradient) + + +def test_qeq_energy_2res_jit(): + rc = 0.6 + xml = XMLIO() + xml.loadXML("tests/data/qeq2.xml") + res = xml.parseResidues() + charges = [a["charge"] for a in res[0]["particles"]] + [a["charge"] for a in res[1]["particles"]] + charges = np.zeros((len(charges),)) + types = [a["type"] for a in res[0]["particles"]] + [a["type"] for a in res[1]["particles"]] + + pdb = app.PDBFile("tests/data/qeq2.pdb") + top = pdb.topology + dmfftop = DMFFTopology(from_top=top) + atoms = [a for a in dmfftop.atoms()] + pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.array(pos) + box = dmfftop.getPeriodicBoxVectors() + hamilt = Hamiltonian("tests/data/qeq2.xml") + + atoms = [a for a in dmfftop.atoms()] + for na, a in enumerate(atoms): + a.meta["charge"] = charges[na] + a.meta["type"] = types[na] + + nblist = NeighborList(box, rc, dmfftop.buildCovMat()) + pairs = nblist.allocate(pos) + + const_list = [] + const_list.append([]) + for ii in range(144): + const_list[-1].append(ii) + const_list.append([]) + for ii in range(144, 147): + const_list[-1].append(ii) + const_val = [0.0, 0.0] + + pot = hamilt.createPotential(dmfftop, nonbondedCutoff=rc*unit.nanometer, nonbondedMethod=app.PME, + ethresh=1e-3, neutral=True, slab=True, constQ=True, + const_list=const_list, const_vals=const_val, + has_aux=True + ) + efunc = jax.jit(pot.getPotentialFunc()) + grad = jax.jit(jax.grad(efunc, argnums=0, has_aux=True)) + aux = { + "q": jnp.array(charges), + "lagmt": jnp.array([1.0, 1.0]) + } + print("Start computing energy and force.") + energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux) + print(aux) + np.testing.assert_almost_equal(energy, 4817.295171, decimal=2) + + grad = jax.grad(efunc, argnums=0, has_aux=True) + gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux) + print(gradient) \ No newline at end of file