
Mixture Density Networks (MDNs) for learning simulation propagators



Ferg-Lab/mdn_propagator

Mixture Density Networks


This package implements Mixture Density Networks (MDNs) for learning simulation propagators. Given a trajectory $X=\{x_0, x_1, x_2, \cdots , x_N\}$ where $x_t \in \mathbb{R}^d$, we learn a propagator $f_{\theta}(x_t)$ as an MDN that predicts the system state $\hat{x}_{t+\tau}$ after a lag time $\tau$: $$f_{\theta}(x_t) = \hat{x}_{t+\tau}$$
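As a minimal sketch of the training setup described above, the NumPy code below forms time-lagged pairs $(x_t, x_{t+\tau})$ from a single trajectory and evaluates the negative log-likelihood of the targets under a Gaussian mixture, which is the loss an MDN is usually trained to minimize. It assumes diagonal-covariance Gaussian components; the helper names `lagged_pairs` and `mixture_nll` are hypothetical and not part of mdn_propagator, whose actual parameterization may differ.

```python
import numpy as np

def lagged_pairs(traj, lag):
    """Split a trajectory of shape (N + 1, d) into inputs x_t and targets x_{t+lag}."""
    if lag < 1:
        raise ValueError("lag must be a positive number of frames")
    traj = np.asarray(traj)
    return traj[:-lag], traj[lag:]

def mixture_nll(x_next, weights, means, stds):
    """Mean negative log-likelihood of targets under a diagonal Gaussian mixture (assumed form).

    x_next:  (B, d)    observed states x_{t+tau}
    weights: (B, K)    mixture weights pi_k(x_t), each row summing to 1
    means:   (B, K, d) component means mu_k(x_t)
    stds:    (B, K, d) component standard deviations sigma_k(x_t)
    """
    diff = (x_next[:, None, :] - means) / stds  # (B, K, d)
    # log N(x | mu_k, diag(sigma_k^2)), summed over the d dimensions
    log_comp = -0.5 * np.sum(diff**2 + np.log(2 * np.pi * stds**2), axis=-1)  # (B, K)
    # log-sum-exp over the K components for numerical stability
    a = np.log(weights) + log_comp
    a_max = a.max(axis=1, keepdims=True)
    log_lik = a_max[:, 0] + np.log(np.exp(a - a_max).sum(axis=1))
    return -log_lik.mean()

traj = np.arange(10.0).reshape(5, 2)      # toy trajectory: N + 1 = 5 frames, d = 2
x_t, x_next = lagged_pairs(traj, lag=1)
print(x_t.shape, x_next.shape)            # (4, 2) (4, 2)

# One standard-normal component centred on each target: NLL = (d / 2) log(2 pi)
B, d = x_next.shape
nll = mixture_nll(x_next, np.ones((B, 1)), x_next[:, None, :], np.ones((B, 1, d)))
print(np.isclose(nll, np.log(2 * np.pi)))  # True
```

With several independent trajectories (such as the three ADP runs in the usage example), pairs would naturally be built within each trajectory and then concatenated, so that no pair spans two unrelated runs.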

Getting Started

Installation

To use mdn_propagator, you will need an environment with the following packages:

For running and plotting examples:

Once you have these packages installed, you can install mdn_propagator in the same environment by running the following from the root of the repository:

$ pip install -e .

Usage

Once installed, you can use the package. This example generates a synthetic trajectory of Alanine Dipeptide (ADP) in the space of the backbone dihedral angles ($\phi , \psi$). More detailed examples can be found in the examples directory.

from mdn_propagator.propagator import Propagator
import torch
import numpy as np

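# load data: three 250 ns ADP trajectories of the backbone dihedrals (phi, psi)
dihedrals_data = np.load('examples/data/alanine-dipeptide-3x250ns-backbone-dihedrals.npz')
phi_psi_data = [dihedrals_data['arr_0'], dihedrals_data['arr_1'], dihedrals_data['arr_2']]
# convert each trajectory to a float tensor of shape (n_frames, 2)
phi_psi_data = [torch.tensor(p).float() for p in phi_psi_data]

# instantiate the model; dim is the number of features per frame
model = Propagator(dim = phi_psi_data[0].size(1))

# fit the model to predict the state one frame ahead (lag = 1)
model.fit(phi_psi_data, lag = 1, max_epochs=100)

# generate a synthetic trajectory of 1E6 steps starting from the first
# frame of the first trajectory ([None] adds a leading batch dimension)
n_steps = int(1E6)
x = phi_psi_data[0][0][None]
syn_traj = model.gen_synthetic_traj(x, n_steps)

# save model checkpoint
model.save('ADP.ckpt')

# load from checkpoint
model = Propagator.load_from_checkpoint('ADP.ckpt')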
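Generating a synthetic trajectory with an MDN propagator amounts to an autoregressive rollout: starting from $x_0$, sample $\hat{x}_{\tau}$ from the predicted mixture, feed it back in as the next input, and repeat, so consecutive frames are separated by the lag time $\tau$. The NumPy sketch below illustrates that loop with a hypothetical `toy_propagator` standing in for a trained network. `sample_mixture`, `toy_propagator`, and `rollout` are illustrative names, not the package's API, and the diagonal-Gaussian mixture is an assumption.

```python
import numpy as np

def sample_mixture(rng, weights, means, stds):
    """Draw one state from a diagonal Gaussian mixture.

    weights: (K,), means: (K, d), stds: (K, d)
    """
    k = rng.choice(len(weights), p=weights)  # pick a component by its weight
    return means[k] + stds[k] * rng.standard_normal(means.shape[1])

def toy_propagator(x):
    """Hypothetical stand-in for a trained MDN: maps x_t to (weights, means, stds)."""
    weights = np.array([0.5, 0.5])
    means = np.stack([0.9 * x, 0.9 * x + 0.1])  # (2, d)
    stds = np.full_like(means, 0.05)
    return weights, means, stds

def rollout(propagator, x0, n_steps, seed=0):
    """Autoregressive rollout: each new frame is sampled conditioned on the previous one."""
    rng = np.random.default_rng(seed)
    traj = np.empty((n_steps + 1, x0.shape[0]))
    traj[0] = x0
    for t in range(n_steps):
        traj[t + 1] = sample_mixture(rng, *propagator(traj[t]))
    return traj

traj = rollout(toy_propagator, np.zeros(2), n_steps=1000)
print(traj.shape)  # (1001, 2)
```

Because each step samples rather than taking the mixture mean, repeated rollouts from the same start state (with different seeds) produce different trajectories, which is what lets the propagator reproduce stochastic dynamics.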


The default network used for the propagator is a simple multilayer perceptron (MLP). Network hyperparameters can be set in the Propagator constructor (see modules for more details):

from mdn_propagator.propagator import Propagator
from torch import nn

model = Propagator(dim = 10, hidden_dim = 256, n_hidden_layers = 2, activation = nn.ReLU, lr = 1e-4)
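To make the constructor arguments concrete, the NumPy sketch below shows how `hidden_dim`, `n_hidden_layers`, and an activation such as ReLU typically shape a plain MLP that maps a state in $\mathbb{R}^d$ to mixture parameters. It is illustrative only: the layer-counting convention, the number of mixture components `K`, and the output-head layout are assumptions, not read from the package's modules, and `init_mlp`/`mlp_forward` are hypothetical helpers.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def init_mlp(dim, hidden_dim, n_hidden_layers, out_dim, seed=0):
    """Hypothetical MLP layout: dim -> hidden_dim (repeated n_hidden_layers times) -> out_dim."""
    rng = np.random.default_rng(seed)
    sizes = [dim] + [hidden_dim] * n_hidden_layers + [out_dim]
    return [(rng.standard_normal((m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp_forward(params, x, activation=relu):
    # activation after every hidden layer
    for W, b in params[:-1]:
        x = activation(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # no activation on the output layer

# Assumed head size for K diagonal Gaussian components in d dimensions:
# K mixture logits + K*d means + K*d (log) standard deviations.
dim, K = 10, 5
params = init_mlp(dim, hidden_dim=256, n_hidden_layers=2, out_dim=K + 2 * K * dim)
out = mlp_forward(params, np.zeros((4, dim)))
print(out.shape)  # (4, 105)
```

Increasing `hidden_dim` or `n_hidden_layers` grows the network's capacity (and parameter count) without changing the size of the mixture head, which depends only on `dim` and the number of components.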

Copyright

Copyright (c) 2022, Kirill Shmilovich

Acknowledgements

Project based on the Computational Molecular Science Python Cookiecutter version 1.1.
