This package implements Mixture Density Networks (MDNs) for learning simulation propagators. Given a trajectory
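As background, an MDN maps an input state x_t to the parameters of a Gaussian mixture over the next state, p(y | x_t) = Σ_k w_k(x_t) N(y | μ_k(x_t), diag σ_k(x_t)²). It is trained by minimizing the negative log-likelihood of the observed next states. The stdlib-only sketch below evaluates that likelihood for a single target. It assumes diagonal-covariance components whose weights already sum to one (in a network these typically come from a softmax). The parameter values are hypothetical, and this is not necessarily how `mdn_propagator` parameterizes its mixture.

```python
import math
from typing import Sequence


def log_gaussian_diag(y: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> float:
    """Log-density of y under a Gaussian with diagonal covariance diag(std**2)."""
    total = 0.0
    for yi, mi, si in zip(y, mean, std):
        z = (yi - mi) / si
        total += -0.5 * z * z - math.log(si) - 0.5 * math.log(2.0 * math.pi)
    return total


def mdn_log_likelihood(weights, means, stds, y) -> float:
    """log sum_k w_k N(y | mu_k, diag(sigma_k^2)), computed with the log-sum-exp trick."""
    terms = [math.log(w) + log_gaussian_diag(y, m, s)
             for w, m, s in zip(weights, means, stds)]
    peak = max(terms)
    return peak + math.log(sum(math.exp(t - peak) for t in terms))


# Hypothetical MDN output for a 2-D state (e.g. phi, psi) with two mixture components.
weights = [0.7, 0.3]
means = [[-1.5, 2.5], [1.0, -0.5]]
stds = [[0.3, 0.4], [0.5, 0.5]]
next_state = [-1.4, 2.6]

nll = -mdn_log_likelihood(weights, means, stds, next_state)
print(f"negative log-likelihood: {nll:.4f}")
```

Working in log space avoids underflow when a target lies far from every component, which is common once the mixture becomes sharply peaked during training.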
To use `mdn_propagator`, you will need an environment with the following packages:
- Python 3.7+
- PyTorch
- PyTorch Lightning
For running and plotting examples:
Once you have these packages installed, you can install `mdn_propagator` in the same environment by running the following from the repository root:

```bash
$ pip install -e .
```
Once installed, you can use the package. The example below generates a synthetic trajectory of Alanine Dipeptide (ADP) in the space of its backbone dihedral angles (φ, ψ), using data provided in the `examples` directory.
```python
import numpy as np
import torch

from mdn_propagator.propagator import Propagator

# Load data: three 250 ns ADP trajectories of (phi, psi) backbone dihedrals
dihedrals_data = np.load('examples/data/alanine-dipeptide-3x250ns-backbone-dihedrals.npz')
phi_psi_data = [dihedrals_data['arr_0'], dihedrals_data['arr_1'], dihedrals_data['arr_2']]
phi_psi_data = [torch.tensor(p).float() for p in phi_psi_data]

# Instantiate the model; dim is the number of features per frame
model = Propagator(dim=phi_psi_data[0].size(1))

# Fit the model with a lag of 1 frame
model.fit(phi_psi_data, lag=1, max_epochs=100)

# Generate a synthetic trajectory, starting from the first frame of the first trajectory
n_steps = int(1e6)
x = phi_psi_data[0][0][None]  # add a leading batch dimension
syn_traj = model.gen_synthetic_traj(x, n_steps)

# Save a model checkpoint
model.save('ADP.ckpt')

# Load the model from the checkpoint
model = Propagator.load_from_checkpoint('ADP.ckpt')
```
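The `lag` argument to `fit` and the `gen_synthetic_traj` call suggest the usual propagator workflow: learn the distribution of x_{t+lag} given x_t from pairs of frames, then generate new data by repeatedly sampling the next state and feeding it back in. The stdlib-only sketch below illustrates both steps under stated assumptions; it is not `mdn_propagator`'s implementation. `make_lagged_pairs`, `rollout` and `toy_sample_next` are hypothetical names, and `toy_sample_next` is a fixed two-component mixture standing in for a trained MDN.

```python
import random
from typing import Callable, List, Sequence, Tuple

State = List[float]


def make_lagged_pairs(trajs: Sequence[Sequence[State]], lag: int) -> List[Tuple[State, State]]:
    """Collect (x_t, x_{t+lag}) pairs within each trajectory, never across trajectory boundaries."""
    pairs = []
    for traj in trajs:
        for t in range(len(traj) - lag):
            pairs.append((traj[t], traj[t + lag]))
    return pairs


def rollout(sample_next: Callable[[State], State], x0: State, n_steps: int) -> List[State]:
    """Autoregressive generation: each sampled state becomes the next input."""
    traj = [x0]
    for _ in range(n_steps):
        traj.append(sample_next(traj[-1]))
    return traj


rng = random.Random(0)  # seeded for reproducibility


def toy_sample_next(x: State) -> State:
    """Hypothetical stand-in for a trained MDN: pick a mixture component, then sample its Gaussian."""
    weights = [0.5, 0.5]
    means = [[xi - 0.1 for xi in x], [xi + 0.1 for xi in x]]
    std = 0.05
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return [rng.gauss(mu, std) for mu in means[k]]


trajs = [[[0.0], [1.0], [2.0]], [[10.0], [11.0]]]
pairs = make_lagged_pairs(trajs, lag=1)
print(len(pairs))  # 3: two from the first trajectory, one from the second

synthetic = rollout(toy_sample_next, [0.0, 0.0], n_steps=5)
print(len(synthetic))  # 6: the starting state plus 5 sampled steps
```

Building pairs per trajectory keeps separate trajectories (such as the three ADP runs) from being stitched into spurious transitions. This sketch's `rollout` returns `n_steps + 1` frames including the starting state; the package's own convention may differ.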
The default network used for the propagator is a simple multilayer perceptron (MLP). Its hyperparameters can be set through the `Propagator` constructor (see `modules` for more details):
```python
from torch import nn

from mdn_propagator.propagator import Propagator

model = Propagator(dim=10, hidden_dim=256, n_hidden_layers=2, activation=nn.ReLU, lr=1e-4)
```
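To get a feel for how `dim`, `hidden_dim` and `n_hidden_layers` set the size of the network, the sketch below counts the weights and biases of a generic fully connected trunk. It assumes the trunk is one `dim -> hidden_dim` layer followed by `n_hidden_layers` layers of `hidden_dim -> hidden_dim`, and it ignores the MDN output head. Both are assumptions, not a description of the layers `mdn_propagator` actually builds; the helper names are hypothetical.

```python
from typing import List, Tuple


def mlp_layer_shapes(dim: int, hidden_dim: int, n_hidden_layers: int) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each linear layer in the assumed trunk."""
    shapes = [(dim, hidden_dim)]
    shapes += [(hidden_dim, hidden_dim)] * n_hidden_layers
    return shapes


def count_linear_params(shapes: List[Tuple[int, int]]) -> int:
    """Each linear layer has in * out weights plus out biases."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


shapes = mlp_layer_shapes(dim=10, hidden_dim=256, n_hidden_layers=2)
print(shapes)                       # [(10, 256), (256, 256), (256, 256)]
print(count_linear_params(shapes))  # 134400
```

Under these assumptions the hidden-to-hidden layers dominate the count, so doubling `hidden_dim` roughly quadruples that part of the network, while changing `dim` only affects the first layer.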
Copyright (c) 2022, Kirill Shmilovich
Project based on the Computational Molecular Science Python Cookiecutter version 1.1.