Skip to content

Commit

Permalink
Merge branch 'SSAGESLabs:main' into funnel_abf
Browse files Browse the repository at this point in the history
  • Loading branch information
gustavor101 authored Oct 2, 2024
2 parents 89b035c + 00d4bca commit fb6bfad
Show file tree
Hide file tree
Showing 23 changed files with 149 additions and 61 deletions.
3 changes: 2 additions & 1 deletion .trunk/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
*logs
*actions
*notifications
*tools
plugins
user_trunk.yaml
user.yaml
shims
tmp
36 changes: 19 additions & 17 deletions .trunk/trunk.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
version: 0.1
runtimes:
enabled:
- go@1.18.3
- node@16.14.2
- python@3.10.3
- go@1.19.5
- node@18.12.1
- python@3.10.8
actions:
enabled:
- trunk-announce
Expand All @@ -12,28 +12,30 @@ actions:
- trunk-fmt-pre-commit
- trunk-check-pre-push
cli:
version: 1.11.1
version: 1.22.5
plugins:
sources:
- id: trunk
ref: v0.0.14
ref: v1.2.1
uri: https://github.com/trunk-io/plugins
lint:
enabled:
- cspell@6.31.1
- svgo@3.0.2
- actionlint@1.6.25
- black@23.7.0
- flake8@6.0.0
- oxipng@9.1.2
- yamllint@1.35.1
- cspell@8.14.4
- svgo@3.3.2
- actionlint@1.7.1
- black@24.8.0
- flake8@7.1.1
- git-diff-check@SYSTEM
- gitleaks@8.17.0
- gitleaks@8.19.2
- hadolint@2.12.0
- isort@5.12.0
- markdownlint@0.35.0
- prettier@3.0.0
- shellcheck@0.9.0
- shfmt@3.5.0
- taplo@0.8.1
- isort@5.13.2
- markdownlint@0.41.0
- prettier@3.3.3
- shellcheck@0.10.0
- shfmt@3.6.0
- taplo@0.9.3
ignore:
- linters: [prettier]
paths:
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
<a href="https://github.com/SSAGESLabs/PySAGES/actions/workflows/trunk.yml" target="_blank">
<img src="https://github.com/SSAGESLabs/PySAGES/actions/workflows/trunk.yml/badge.svg?branch=main" alt="Trunk">
</a>
&nbsp;
<a href="https://doi.org/10.1038/s41524-023-01189-z" target="_blank">
<img src="https://img.shields.io/badge/DOI-10.1038%2Fs41524--023--01189--z-blue" alt="Cite PySAGES">
</a>
</p>
</h1>

Expand Down
3 changes: 1 addition & 2 deletions examples/hoomd-blue/spectral_abf/Butane-SpectralABF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,7 @@
"A = result[\"free_energy\"]\n",
"# Alternatively:\n",
"# fes_fn = result[\"fes_fn\"]\n",
"# A = fes_fn(mesh)\n",
"A = A.max() - A"
"# A = fes_fn(mesh)"
]
},
{
Expand Down
1 change: 0 additions & 1 deletion examples/hoomd-blue/spectral_abf/Butane-SpectralABF.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ A = result["free_energy"]
# Alternatively:
# fes_fn = result["fes_fn"]
# A = fes_fn(mesh)
A = A.max() - A
```

```python colab={"base_uri": "https://localhost:8080/", "height": 302} id="7_d_XfVLLkbI" outputId="e35db259-31f8-4a3b-b1fa-7e91a8a5c88a"
Expand Down
3 changes: 1 addition & 2 deletions examples/hoomd-blue/spectral_abf/butane.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def get_args(argv):
("time-steps", "t", int, 5e5, "Number of simulation steps"),
]
parser = argparse.ArgumentParser(description="Example script to run SpectralABF")
for (name, short, T, val, doc) in available_args:
for name, short, T, val, doc in available_args:
parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc)

return parser.parse_args(argv)
Expand All @@ -283,7 +283,6 @@ def main(argv=[]):
mesh = result["mesh"]
fes_fn = result["fes_fn"]
A = fes_fn(mesh)
A = A.max() - A

# plot the free energy
fig, ax = plt.subplots()
Expand Down
1 change: 0 additions & 1 deletion examples/hoomd3/spectral_abf/butane.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def main(argv=[]):
mesh = result["mesh"]
fes_fn = result["fes_fn"]
A = fes_fn(mesh)
A = A.max() - A

# plot the free energy
fig, ax = plt.subplots()
Expand Down
3 changes: 1 addition & 2 deletions examples/openmm/abf/alanine-dipeptide_openmm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python3


import matplotlib.pyplot as plt
import numpy

Expand Down Expand Up @@ -115,7 +114,7 @@ def post_run_action(**kwargs):
def main():
cvs = [DihedralAngle((4, 6, 8, 14)), DihedralAngle((6, 8, 14, 16))]
grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(32, 32), periodic=True)
method = ABF(cvs, grid)
method = ABF(cvs, grid, use_pinv=True)

raw_result = pysages.run(method, generate_simulation, 25, post_run_action=post_run_action)
result = pysages.analyze(raw_result, topology=(14,))
Expand Down
1 change: 0 additions & 1 deletion examples/openmm/spectral_abf/ADP_SpectralABF.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@
"# mesh = result[\"mesh\"]\n",
"# fes_fn = result[\"fes_fn\"]\n",
"# A = fes_fn(mesh)\n",
"A = A.max() - A\n",
"A = A.reshape(grid.shape)"
]
},
Expand Down
1 change: 0 additions & 1 deletion examples/openmm/spectral_abf/ADP_SpectralABF.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ A = result["free_energy"]
# mesh = result["mesh"]
# fes_fn = result["fes_fn"]
# A = fes_fn(mesh)
A = A.max() - A
A = A.reshape(grid.shape)
```

Expand Down
3 changes: 1 addition & 2 deletions examples/openmm/spectral_abf/alanine-dipeptide.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_args(argv):
("time-steps", "t", int, 5e5, "Number of simulation steps"),
]
parser = argparse.ArgumentParser(description="Example script to run Spectral ABF")
for (name, short, T, val, doc) in available_args:
for name, short, T, val, doc in available_args:
parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc)
return parser.parse_args(argv)

Expand Down Expand Up @@ -108,7 +108,6 @@ def main(argv=[]):

# Set min free energy to zero
A = fes_fn(xi)
A = A.max() - A
A = A.reshape(plot_grid.shape)

# plot and save free energy to a PNG file
Expand Down
5 changes: 2 additions & 3 deletions pysages/backends/openmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,12 @@ def bias(snapshot, state, sync_backend):
"""Adds the computed bias to the forces."""
if state.bias is None:
return
biases = adapt(state.bias)
# Forces may be computed asynchronously on the GPU, so we need to
# synchronize them before applying the bias.
sync_backend()
biases = adapt(state.bias)
forces = view(snapshot.forces)
biases = view(biases.block_until_ready())
forces += biases
forces += view(biases.block_until_ready())
sync_forces()

def dimensionality():
Expand Down
4 changes: 2 additions & 2 deletions pysages/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.ctypeslib import as_ctypes_type

from pysages.typing import JaxArray
from pysages.utils import dispatch
from pysages.utils import dispatch, unsafe_buffer_pointer


def cupy_helpers():
Expand Down Expand Up @@ -38,7 +38,7 @@ def view(array: JaxArray):
# NOTE: We need a more general strategy to handle
# `SharedDeviceArray`s and `GlobalDeviceArray`s.
ptype = ctypes.POINTER(as_ctypes_type(array.dtype))
addr = array.device_buffer.unsafe_buffer_pointer()
addr = unsafe_buffer_pointer(array)
ptr = ctypes.cast(ctypes.c_void_p(addr), ptype)
return numba.carray(ptr, array.shape)

Expand Down
15 changes: 9 additions & 6 deletions pysages/methods/abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pysages.methods.restraints import apply_restraints
from pysages.methods.utils import numpyfy_vals
from pysages.typing import JaxArray, NamedTuple
from pysages.utils import dispatch, solve_pos_def
from pysages.utils import dispatch, linear_solver


class ABFState(NamedTuple):
Expand Down Expand Up @@ -103,13 +103,19 @@ class ABF(GriddedSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}

def __init__(self, cvs, grid, **kwargs):
super().__init__(cvs, grid, **kwargs)
self.N = np.asarray(self.kwargs.get("N", 500))
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers, *args, **kwargs):
"""
Expand Down Expand Up @@ -158,6 +164,7 @@ def _abf(method, snapshot, helpers):
dt = snapshot.dt
dims = grid.shape.size
natoms = np.size(snapshot.positions, 0)
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
estimate_force = build_force_estimator(method)

Expand Down Expand Up @@ -201,11 +208,7 @@ def update(state, data):
xi, Jxi = cv(data)

p = data.momenta
# The following could equivalently be computed as `linalg.pinv(Jxi.T) @ p`
# (both seem to have the same performance).
# Another option to benchmark against is
# Wp = linalg.tensorsolve(Jxi @ Jxi.T, Jxi @ p)
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
# Second order backward finite difference
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt

Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver

# Aliases
f32 = np.float32
Expand Down Expand Up @@ -148,6 +148,11 @@ class CFF(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -171,6 +176,7 @@ def __init__(self, cvs, grid, topology, kT, **kwargs):
self.fmodel = MLP(dims, dims, topology, transform=scale)
self.optimizer = kwargs.get("optimizer", default_optimizer)
self.foptimizer = kwargs.get("foptimizer", default_foptimizer)
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers):
return _cff(self, snapshot, helpers)
Expand All @@ -187,6 +193,7 @@ def _cff(method: CFF, snapshot, helpers):
fps, _ = unpack(method.fmodel.parameters)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
learn_free_energy = build_free_energy_learner(method)
estimate_force = build_force_estimator(method)
Expand Down Expand Up @@ -221,7 +228,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve, normalize
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver


class FUNNState(NamedTuple):
Expand Down Expand Up @@ -126,6 +126,11 @@ class FUNN(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -142,6 +147,7 @@ def __init__(self, cvs, grid, topology, **kwargs):
self.model = MLP(dims, dims, topology, transform=scale)
default_optimizer = LevenbergMarquardt(reg=L2Regularization(1e-6))
self.optimizer = kwargs.get("optimizer", default_optimizer)
self.use_pinv = self.kwargs.get("use_pinv", False)

def build(self, snapshot, helpers):
return _funn(self, snapshot, helpers)
Expand All @@ -160,6 +166,7 @@ def _funn(method, snapshot, helpers):
ps, _ = unpack(method.model.parameters)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
learn_free_energy_grad = build_free_energy_grad_learner(method)
estimate_free_energy_grad = build_force_estimator(method)
Expand All @@ -186,7 +193,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
11 changes: 9 additions & 2 deletions pysages/methods/sirens.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pysages.ml.training import NNData, build_fitting_function, convolve
from pysages.ml.utils import blackman_kernel, pack, unpack
from pysages.typing import JaxArray, NamedTuple, Tuple
from pysages.utils import dispatch, first_or_all, solve_pos_def
from pysages.utils import dispatch, first_or_all, linear_solver


class SirensState(NamedTuple): # pylint: disable=R0903
Expand Down Expand Up @@ -146,6 +146,11 @@ class Sirens(NNSamplingMethod):
If provided, indicate that harmonic restraints will be applied when any
collective variable lies outside the box from `restraints.lower` to
`restraints.upper`.
use_pinv: Optional[Bool] = False
If set to True, the product `W @ p` will be estimated using
`np.linalg.pinv` rather than using the `scipy.linalg.solve` function.
This is computationally more expensive but numerically more stable.
"""

snapshot_flags = {"positions", "indices", "momenta"}
Expand All @@ -172,6 +177,7 @@ def __init__(self, cvs, grid, topology, **kwargs):
scale = partial(_scale, grid=grid)
self.model = Siren(dims, 1, topology, transform=scale)
self.optimizer = optimizer
self.use_pinv = self.kwargs.get("use_pinv", False)

def __check_init_invariants__(self, mode, kT, optimizer):
if mode not in ("abf", "cff"):
Expand Down Expand Up @@ -202,6 +208,7 @@ def _sirens(method: Sirens, snapshot, helpers):
ps, _ = unpack(method.model.parameters)

# Helper methods
tsolve = linear_solver(method.use_pinv)
get_grid_index = build_indexer(grid)
learn_free_energy = build_free_energy_learner(method)
estimate_force = build_force_estimator(method)
Expand Down Expand Up @@ -244,7 +251,7 @@ def update(state, data):
xi, Jxi = cv(data)
#
p = data.momenta
Wp = solve_pos_def(Jxi @ Jxi.T, Jxi @ p)
Wp = tsolve(Jxi, p)
dWp_dt = (1.5 * Wp - 2.0 * state.Wp + 0.5 * state.Wp_) / dt
#
I_xi = get_grid_index(xi)
Expand Down
Loading

0 comments on commit fb6bfad

Please sign in to comment.