Skip to content

Commit

Permalink
Add script for DMFF model saving. (#109)
Browse files Browse the repository at this point in the history
* Add issue templates for feature request and bug-report

* Add script for dmff model saving.

* Remove issue template from devel branch.

* debug workflow

* remove debug

* Update ut.yml. Install mdtraj by conda.
  • Loading branch information
dingye18 authored Aug 24, 2023
1 parent c8aa1f1 commit 9b01bfe
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
- name: Install Dependencies
run: |
source $CONDA/bin/activate
conda create -n dmff python=${{ matrix.python-version }} numpy openmm==7.7.0 pytest rdkit biopandas openbabel -c conda-forge
conda create -n dmff -y python=${{ matrix.python-version }} numpy openmm==7.7.0 mdtraj=1.9.7 pytest rdkit biopandas openbabel -c conda-forge
conda activate dmff
pip install --upgrade pip
pip install jax==0.3.15 jaxlib==0.3.15 jax_md==0.2.0 mdtraj==1.9.7 pymbar==4.0.1 tqdm
pip install jax==0.3.15 jaxlib==0.3.15 jax_md==0.2.0 pymbar==4.0.1 chex==0.1.4 dm-haiku==0.0.7 tqdm
- name: Install DMFF
run: |
source $CONDA/bin/activate dmff && pip install .
Expand Down
76 changes: 76 additions & 0 deletions backend/save_dmff2tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import dmff
from dmff import NeighborList
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
# The model is saved in double precision by default.
# Since forces accuracy in double precision is needed in molecular dynamics simulations,
# we need to enable double precision in JAX.
from jax import config
config.update("jax_enable_x64", True)
import openmm.app as app
import openmm.unit as unit
import tensorflow as tf

import os
import argparse

gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)

def create_dmff_potential(input_pdb_file, ff_xml_files):
pdb = app.PDBFile(input_pdb_file)
h = dmff.Hamiltonian(*ff_xml_files)
pot = h.createPotential(pdb.topology,
nonbondedMethod=app.PME,
nonbondedCutoff=1.2 *
unit.nanometer)
pot_func = pot.getPotentialFunc()
a, b, c = pdb.topology.getPeriodicBoxVectors()
a = a.value_in_unit(unit.nanometer)
b = b.value_in_unit(unit.nanometer)
c = c.value_in_unit(unit.nanometer)

engrad = jax.value_and_grad(pot_func, 0)

covalent_map = h.getGenerators()[-1].covalent_map

def potential_engrad(positions, box, pairs):
if jnp.shape(pairs)[-1] == 2:
nbond = covalent_map[pairs[:, 0], pairs[:, 1]]
pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1)

return engrad(positions, box, pairs, h.paramtree)

return pdb, potential_engrad, covalent_map, pot, h


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_pdb", dest="input_pdb", help="input pdb file. Box information is required in the pdb file.")
parser.add_argument("--xml_files", dest="xml_files", nargs="+", help=".xml files with parameters are derived from DMFF.")
parser.add_argument("--output", dest="output", help="output directory")
args = parser.parse_args()

input_pdb = args.input_pdb
ff_xml_files = args.xml_files
output_dir = args.output
if output_dir[-1] == "/":
output_dir = output_dir[:-1]
if not os.path.exists(output_dir):
os.mkdir(output_dir)

pdb, pot_grad, covalent_map, pot, h = create_dmff_potential(input_pdb, ff_xml_files)

natoms = pdb.getTopology().getNumAtoms()

f_tf = jax2tf.convert(
jax.jit(pot_grad),
polymorphic_shapes=["("+str(natoms)+", 3)", "(3, 3)", "(b, 2)"]
)
dmff_model = tf.Module()
dmff_model.f = tf.function(f_tf, autograph=False,
input_signature=[tf.TensorSpec(shape=[natoms,3], dtype=tf.float64), tf.TensorSpec(shape=[3,3], dtype=tf.float64), tf.TensorSpec(shape=tf.TensorShape([None, 2]), dtype=tf.int32)])

tf.saved_model.save(dmff_model, output_dir, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

0 comments on commit 9b01bfe

Please sign in to comment.