Skip to content

Commit

Permalink
Change jaxopt root finder to be jit-able
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 24, 2023
1 parent a03c841 commit 9b03c2b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: [3.10]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
15 changes: 12 additions & 3 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

try:
import jaxopt
try:
from jaxopt import Broyden
JAXOPT_OLD = False
except ImportError:
JAXOPT_OLD = True
print("jaxopt is too old. The QEQ potential function cannot be jitted. Please update jaxopt to the latest version for speed concern.")
except ImportError:
print("jaxopt not found, QEQ cannot be used.")
import jax
Expand Down Expand Up @@ -283,9 +289,12 @@ def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
else:
b_value = jnp.concatenate([self.init_q, self.init_lagmt])
rf = jaxopt.ScipyRootFinding(
optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
)
if JAXOPT_OLD:
rf = jaxopt.ScipyRootFinding(
optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
)
else:
rf = jaxopt.Broyden(fun=E_grads, tol=1e-10)
b_0, _ = rf.run(
b_value,
chi,
Expand Down
34 changes: 29 additions & 5 deletions tests/test_admp/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ def test_init(self):
H2 = Hamiltonian('tests/data/admp_nonpol.xml')
pdb = app.PDBFile('tests/data/water_dimer.pdb')
potential = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential_aux = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5, has_aux=True)
potential1 = H1.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential2 = H2.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)

yield potential, potential1, potential2, H.paramset, H1.paramset, H2.paramset
yield potential, potential_aux, potential1, potential2, H.paramset, H1.paramset, H2.paramset

def test_ADMPPmeForce(self, pot_prm):
potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -55,7 +56,7 @@ def test_ADMPPmeForce(self, pot_prm):


def test_ADMPPmeForce_jit(self, pot_prm):
potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -73,10 +74,33 @@ def test_ADMPPmeForce_jit(self, pot_prm):
energy, grad = j_pot_pme(positions, box, pairs, paramset.parameters)
print('hahahah', energy)
np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1)

def test_ADMPPmeForce_aux(self, pot_prm):
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
positions = jnp.array(positions)
a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
box = jnp.array([a, b, c])
covalent_map = potential.meta["cov_map"]
# neighbor list
nblist = NeighborList(box, rc, covalent_map)
nblist.allocate(positions)
pairs = nblist.pairs

aux = {
"U_ind": jnp.zeros((len(positions),3)),
}
pot = potential_aux.getPotentialFunc(names=["ADMPPmeForce"])
j_pot_pme = jit(value_and_grad(pot, has_aux=True))
(energy, grad), aux = j_pot_pme(positions, box, pairs, paramset.parameters, aux=aux)
print('hahahah', energy)
np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1)


def test_ADMPPmeForce_mono(self, pot_prm):
potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -97,7 +121,7 @@ def test_ADMPPmeForce_mono(self, pot_prm):


def test_ADMPPmeForce_nonpol(self, pot_prm):
potential, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand Down
66 changes: 63 additions & 3 deletions tests/test_admp/test_qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dmff.api.xmlio import XMLIO
from dmff import NeighborList
import jax.numpy as jnp
import jax
import numpy as np


Expand Down Expand Up @@ -38,7 +39,7 @@ def test_qeq_energy():
np.testing.assert_almost_equal(energy, -37.84692763, decimal=3)


def test_qeq_energy2():
def test_qeq_energy_2res():
rc = 0.6
xml = XMLIO()
xml.loadXML("tests/data/qeq2.xml")
Expand All @@ -63,7 +64,6 @@ def test_qeq_energy2():

nblist = NeighborList(box, rc, dmfftop.buildCovMat())
pairs = nblist.allocate(pos)
print(pairs)

const_list = []
const_list.append([])
Expand All @@ -86,4 +86,64 @@ def test_qeq_energy2():
}
energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(aux)
np.testing.assert_almost_equal(energy, 4817.295171, decimal=3)
np.testing.assert_almost_equal(energy, 4817.295171, decimal=2)

grad = jax.grad(efunc, argnums=0, has_aux=True)
gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(gradient)


def test_qeq_energy_2res_jit():
rc = 0.6
xml = XMLIO()
xml.loadXML("tests/data/qeq2.xml")
res = xml.parseResidues()
charges = [a["charge"] for a in res[0]["particles"]] + [a["charge"] for a in res[1]["particles"]]
charges = np.zeros((len(charges),))
types = [a["type"] for a in res[0]["particles"]] + [a["type"] for a in res[1]["particles"]]

pdb = app.PDBFile("tests/data/qeq2.pdb")
top = pdb.topology
dmfftop = DMFFTopology(from_top=top)
atoms = [a for a in dmfftop.atoms()]
pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
pos = jnp.array(pos)
box = dmfftop.getPeriodicBoxVectors()
hamilt = Hamiltonian("tests/data/qeq2.xml")

atoms = [a for a in dmfftop.atoms()]
for na, a in enumerate(atoms):
a.meta["charge"] = charges[na]
a.meta["type"] = types[na]

nblist = NeighborList(box, rc, dmfftop.buildCovMat())
pairs = nblist.allocate(pos)

const_list = []
const_list.append([])
for ii in range(144):
const_list[-1].append(ii)
const_list.append([])
for ii in range(144, 147):
const_list[-1].append(ii)
const_val = [0.0, 0.0]

pot = hamilt.createPotential(dmfftop, nonbondedCutoff=rc*unit.nanometer, nonbondedMethod=app.PME,
ethresh=1e-3, neutral=True, slab=True, constQ=True,
const_list=const_list, const_vals=const_val,
has_aux=True
)
efunc = jax.jit(pot.getPotentialFunc())
grad = jax.jit(jax.grad(efunc, argnums=0, has_aux=True))
aux = {
"q": jnp.array(charges),
"lagmt": jnp.array([1.0, 1.0])
}
print("Start computing energy and force.")
energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(aux)
np.testing.assert_almost_equal(energy, 4817.295171, decimal=2)

grad = jax.grad(efunc, argnums=0, has_aux=True)
gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(gradient)

0 comments on commit 9b03c2b

Please sign in to comment.