Neural networks are promising models for enhancing the accuracy of classical molecular simulations. However, training accurate models is challenging. chemtrain is a framework for learning sophisticated neural network potential models by combining customizable training routines with advanced training algorithms. This combination makes it possible to include high-quality reference data from simulations and experiments and to lower the computational cost of training by combining algorithms with complementary strengths.
chemtrain is written in JAX, integrating with the differentiable MD engine JAX, M.D. Therefore, chemtrain leverages end-to-end differentiable physics and hardware acceleration through GPUs to provide flexibility at scale.
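To make "end-to-end differentiable physics" concrete, the following minimal sketch uses plain JAX only, not the chemtrain or JAX, M.D. API. The Lennard-Jones potential, the function names, and the three-particle configuration are illustrative assumptions. The sketch shows how automatic differentiation yields both forces (gradients with respect to positions) and gradients with respect to potential parameters, which is the kind of quantity a training algorithm needs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def lennard_jones_energy(positions, sigma=1.0, epsilon=1.0):
    """Total Lennard-Jones energy of a small cluster (no periodic boundaries)."""
    n = positions.shape[0]
    i, j = np.triu_indices(n, k=1)  # static index arrays of unique pairs i < j
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    sr6 = (sigma / r) ** 6
    return jnp.sum(4.0 * epsilon * (sr6 ** 2 - sr6))


# Forces are the negative gradient of the energy with respect to positions.
force_fn = jax.jit(lambda positions: -jax.grad(lennard_jones_energy)(positions))

# The same machinery differentiates with respect to the potential parameters.
energy_and_param_grad = jax.value_and_grad(lennard_jones_energy, argnums=(1, 2))

# Hypothetical three-particle configuration, chosen only for illustration.
positions = jnp.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.0, 1.2, 0.0]])
forces = force_fn(positions)
energy, (d_sigma, d_epsilon) = energy_and_param_grad(positions, 1.0, 1.0)

print("energy:", energy)
print("forces:", forces)
print("dE/dsigma:", d_sigma, "dE/depsilon:", d_epsilon)
```

Because the energy is linear in `epsilon`, the parameter gradient `d_epsilon` should equal the energy itself at `epsilon = 1`, which is a quick sanity check on the differentiation.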
To get started with chemtrain and its most important algorithms, we provide simple toy examples. These examples run easily on a CPU and are sufficient to illustrate the basic concepts of the algorithms.
For a more extensive overview of implemented algorithms, please refer to the documentation of the trainers module.
To see the usage of chemtrain in real examples, we implemented the training procedures of some recent papers:
- CG Alanine Dipeptide in Implicit Water
- CG Water on Structural Data
- AT Titanium on Fused Simulation and Experimental Data
We recommend viewing the examples in the reference documentation.
chemtrain can be installed with pip:
pip install chemtrain --upgrade
The above command installs JAX for CPU. Running chemtrain on the GPU requires the installation of a special JAX version. Please follow the JAX Installation Instructions.
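After installing, it can be useful to confirm which backend JAX actually selected. The snippet below is a small sketch; the helper name `describe_jax_backend` is hypothetical, while `jax.default_backend()` and `jax.devices()` are standard JAX calls. The helper returns `None` if JAX is not importable.

```python
def describe_jax_backend():
    """Return JAX version, default backend, and devices, or None without JAX."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(device) for device in jax.devices()],
    }


print(describe_jax_backend())
```

If the reported backend is `cpu` on a machine with a GPU, the CPU-only JAX build is most likely still installed.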
Note: chemtrain installs jax == 0.4.30, which is, in principle, incompatible with jax_md <= 0.1.29 but resolves an XLA issue which can prevent training. By importing chemtrain or the jax_md_mod module before importing jax_md, the compatibility is restored by a simple patch.
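The import-order requirement can be checked at runtime. The sketch below relies only on the Python standard library; the helper names are hypothetical and not part of chemtrain. It assumes that the order of entries in `sys.modules` reflects the order in which modules were first imported, and that the PyPI distribution names are `jax`, `jax-md`, and `chemtrain`.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def import_position(module_name, modules=None):
    """Return the position of `module_name` in the import order, or None."""
    modules = sys.modules if modules is None else modules
    for position, name in enumerate(list(modules)):
        if name == module_name:
            return position
    return None


def patch_applied_in_time(modules=None):
    """Heuristic: was chemtrain or jax_md_mod imported before jax_md?

    Returns None when jax_md has not been imported at all.
    """
    jax_md_pos = import_position("jax_md", modules)
    if jax_md_pos is None:
        return None
    candidates = (import_position(name, modules)
                  for name in ("chemtrain", "jax_md_mod"))
    return any(pos is not None and pos < jax_md_pos for pos in candidates)


for package in ("jax", "jax-md", "chemtrain"):
    print(package, installed_version(package))
print("patch imported before jax_md:", patch_applied_in_time())
```

In practice, placing `import chemtrain` at the top of a script, before any `import jax_md`, is the simplest way to satisfy the requirement described in the note.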
Some parts of chemtrain require additional packages.
To install these, provide the all option.
pip install 'chemtrain[all]' --upgrade
The lines below install chemtrain from source for development purposes.
git clone git@github.com:tummfm/chemtrain.git
cd chemtrain
pip install -e '.[all,docs,test]'
This command additionally installs the requirements to run the tests
pytest tests
and to build the documentation (e.g., in HTML)
make -C docs html
Within the repository, we provide the following directories:
chemtrain/: Source code of the chemtrain package. The package consists of the following submodules:
- data: Loading and preprocessing of microscopic reference data
- ensemble: Sampling from and evaluating quantities for ensembles
- learn: Lower level implementations of training algorithms
- quantity: Learnable microscopic and macroscopic quantities
- trainers: High-level API to training algorithms
docs/: Source code of the documentation.
examples/: Example Jupyter Notebooks as provided in the documentation. Additionally, the examples/data/ folder contains some example data for the toy examples. The other Jupyter Notebooks download data automatically from the sources provided in the original papers.
jax_md_mod/: Source code of the JAX, M.D. modifications. In the long term, we aim to integrate these modifications into the main JAX, M.D. repository.
tests/: Unit tests for the chemtrain package, supplementing the testing through a reproduction of published paper results.
If you use chemtrain, please cite the following preprint:
@misc{fuchs2024chemtrain,
title={chemtrain: Learning Deep Potential Models via Automatic Differentiation and Statistical Physics},
author={Paul Fuchs and Stephan Thaler and Sebastien Röcken and Julija Zavadlav},
year={2024},
eprint={2408.15852},
archivePrefix={arXiv},
primaryClass={physics.chem-ph},
url={https://arxiv.org/abs/2408.15852},
}
Contributions are always welcome! Please open a pull request to discuss the code additions.
For questions or discussions, please open an Issue on GitHub.