Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface for restarts #254

Merged
merged 9 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions examples/hoomd3/restart/restart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#!/usr/bin/env python3

import pickle

import hoomd
import hoomd.dlext
import hoomd.md
import matplotlib.pyplot as plt
import numpy as np

import pysages
from pysages.backends import SamplingContext
from pysages.colvars import Component
from pysages.methods import HarmonicBias, HistogramLogger


def plot(xi_hist, target_hist, lim, name):
fig, ax = plt.subplots()

ax.set_xlabel(r"CV $\xi_i$")
ax.set_ylabel(r"$p(\xi_i)$")

x = np.linspace(lim[0], lim[1], xi_hist[0].shape[0])

for i in range(len(xi_hist)):
(line,) = ax.plot(x, xi_hist[i], label="i= {0}".format(i))
ax.plot(x, target_hist[i], "--", color=line.get_color())

ax.legend(loc="best")
fig.savefig(str(name) + ".png")
plt.close(fig)


def validate_hist(xi_hist, target, epsilon=0.1):
assert len(xi_hist) == len(target)
for i in range(len(xi_hist)):
val = np.sqrt(np.mean((xi_hist[i] - target[i]) ** 2))
if val > epsilon:
raise RuntimeError(f"Biased histogram deviation too large: {val} epsilon {epsilon}")


def get_target_dist(center, k, lim, bins):
x = np.linspace(lim[0], lim[1], bins)
p = np.exp(-0.5 * k * (x - center) ** 2)
# norm numerically
p *= (lim[1] - lim[0]) / np.sum(p)
return p


def generate_context(device=hoomd.device.CPU(), seed=0, gamma=1.0, **kwargs):
sim = hoomd.Simulation(device=device, seed=seed)
sim.create_state_from_gsd("start.gsd")
integrator = hoomd.md.Integrator(dt=0.01)

nl = hoomd.md.nlist.Cell(buffer=0.4)
dpd = hoomd.md.pair.DPD(nlist=nl, kT=1.0, default_r_cut=1.0)
dpd.params[("A", "A")] = dict(A=kwargs.get("A", 5.0), gamma=gamma)
integrator.forces.append(dpd)
nve = hoomd.md.methods.NVE(filter=hoomd.filter.All())
integrator.methods.append(nve)
sim.operations.integrator = integrator
return sim


def main():
cvs = [Component([0], 2)]
cvs += [Component([0], 1)]
cvs += [Component([0], 0)]

center_cv = [0.0]
center_cv += [1.0, -0.3]

k = 15
method = HarmonicBias(cvs, k, center_cv)
callback = HistogramLogger(100)

# Parameters for plotting the histograms
Lmax = 5.0
bins = 25
target_hist = []
for i in range(len(center_cv)):
target_hist.append(get_target_dist(center_cv[i], k, (-Lmax / 2, Lmax / 2), bins))
lims = [(-Lmax / 2, Lmax / 2) for i in range(3)]

# Running from a SamplingContext. This is only needed for restarting the
# simulation within the same script/notebook and shown here as an example.
# Generally, prefer starting without a SamplingContext, that is
#
# state = pysages.run(method, generate_context, timesteps, callback)
#
# instead of the two lines below
sampling_context = SamplingContext(method, generate_context, callback)
state = pysages.run(sampling_context, int(1e5)) # run a first time

# Plot the histogram so far
hist, edges = callback.get_histograms(bins=bins, range=lims)
hist_list = [
np.sum(hist, axis=(1, 2)) / (Lmax**2),
np.sum(hist, axis=(0, 2)) / (Lmax**2),
np.sum(hist, axis=(0, 1)) / (Lmax**2),
]
plot(hist_list, target_hist, (-Lmax / 2, Lmax / 2), 1)
validate_hist(hist_list, target_hist)

# Run a second time within the same script when using a SamplingContext
state = pysages.run(sampling_context, int(1e4))

# Plot the histogram with the newly collected info
hist, edges = callback.get_histograms(bins=bins, range=lims)
hist_list = [
np.sum(hist, axis=(1, 2)) / (Lmax**2),
np.sum(hist, axis=(0, 2)) / (Lmax**2),
np.sum(hist, axis=(0, 1)) / (Lmax**2),
]
plot(hist_list, target_hist, (-Lmax / 2, Lmax / 2), 2)

# Dump the pickle file for restart. This is the standard way to
# save a system's information to perform a restart in a new run.
with open("restart.pickle", "wb") as f:
pickle.dump(state, f)

# Load the restart file. This is how to run a pysages run from a
# previously stored state.
with open("restart.pickle", "rb") as f:
state = pickle.load(f)

# When restarting, run the system using the same generate_context function!
state = pysages.run(state, generate_context, int(1e4))

# Plot all the accumulated data
hist, edges = callback.get_histograms(bins=bins, range=lims)
hist_list = [
np.sum(hist, axis=(1, 2)) / (Lmax**2),
np.sum(hist, axis=(0, 2)) / (Lmax**2),
np.sum(hist, axis=(0, 1)) / (Lmax**2),
]
plot(hist_list, target_hist, (-Lmax / 2, Lmax / 2), 3)


if __name__ == "__main__":
main()
Binary file added examples/hoomd3/restart/start.gsd
Binary file not shown.
2 changes: 1 addition & 1 deletion pysages/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

from .core import ( # noqa: E402, F401
ContextWrapper,
JaxMDContext,
JaxMDContextState,
SamplingContext,
supported_backends,
)
30 changes: 17 additions & 13 deletions pysages/backends/ase.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

from inspect import signature
from inspect import Parameter, signature
from typing import Callable, NamedTuple

from ase.calculators.calculator import Calculator
from jax import jit
from jax import numpy as np

from pysages.backends.core import ContextWrapper
from pysages.backends.core import SamplingContext
from pysages.backends.snapshot import (
Box,
HelperMethods,
Expand All @@ -17,7 +17,6 @@
build_data_querier,
)
from pysages.backends.utils import view
from pysages.methods import SamplingMethod
from pysages.utils import ToCPU, copy


Expand Down Expand Up @@ -48,8 +47,8 @@ def __init__(self, context, method_bundle, callback: Callable):
self._calculator = atoms.calc
self._context = context
self._biased_forces = initial_snapshot.forces
self._default_properties = list(sig["properties"].default)
self._default_changes = list(sig["system_changes"].default)
self._default_properties = list(_calculator_defaults(sig, "properties"))
self._default_changes = list(_calculator_defaults(sig, "system_changes"))
for p in ("energy", "forces"):
if p not in self._default_properties:
self._default_properties.append(p)
Expand Down Expand Up @@ -90,7 +89,7 @@ def restore(self, prev_snapshot):
atoms = self.atoms
momenta, masses = prev_snapshot.vel_mass
atoms.set_positions(prev_snapshot.positions)
atoms.set_masses(masses) # masses need to be set before momenta
atoms.set_masses(masses.flatten()) # masses need to be set before momenta
atoms.set_momenta(momenta, apply_constraint=False)
atoms.set_cell(list(prev_snapshot.box.H))
self.snapshot = prev_snapshot
Expand Down Expand Up @@ -120,6 +119,12 @@ def take_snapshot(simulation, forces=None):
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)


def _calculator_defaults(sig, arg, default=[]):
fallback = Parameter("_", Parameter.KEYWORD_ONLY, default=default)
val = sig.get(arg, fallback).default
return val if type(val) is list else default


def build_snapshot_methods(context, sampling_method):
def indices(snapshot):
return snapshot.ids
Expand Down Expand Up @@ -153,18 +158,17 @@ class View(NamedTuple):
synchronize: Callable


def bind(
wrapped_context: ContextWrapper, sampling_method: SamplingMethod, callback: Callable, **kwargs
):
def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
"""
Entry point for the backend code, it gets called when the simulation
context is wrapped within `pysages.run`.
"""
context = wrapped_context.context
context = sampling_context.context
sampling_method = sampling_context.method
snapshot = take_snapshot(context)
helpers = build_helpers(wrapped_context.view, sampling_method)
helpers = build_helpers(sampling_context, sampling_method)
method_bundle = sampling_method.build(snapshot, helpers)
sampler = Sampler(context, method_bundle, callback)
wrapped_context.view = View((lambda: None))
wrapped_context.run = context.run
sampling_context.view = View((lambda: None))
sampling_context.run = context.run
return sampler
38 changes: 21 additions & 17 deletions pysages/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,52 +58,56 @@ class JaxMDContext(NamedTuple):
dt: Float


class ContextWrapper:
class SamplingContext:
"""
PySAGES simulation context. Manages access to the backend-dependent simulation context.
"""

def __init__(self, context, sampling_method, callback: Callable = None, **kwargs):
def __init__(
self,
sampling_method,
context_generator: Callable,
callback: Optional[Callable] = None,
context_args: dict = {},
**kwargs,
):
"""
Automatically identifies the backend and binds the sampling method to
the simulation context.
"""
self._backend_name = None
context = context_generator(**context_args)
module_name = type(context).__module__

if module_name.startswith("ase.md"):
self._backend_name = "ase"
elif module_name.startswith("hoomd"):
self._backend_name = "hoomd"
elif isinstance(context, JaxMDContext):
self._backend_name = "jax-md"
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
self._backend_name = "openmm"
elif isinstance(context, JaxMDContext):
self._backend_name = "jax-md"

if self._backend_name is not None:
self._backend = import_module("." + self._backend_name, package="pysages.backends")
else:
if self._backend_name is None:
backends = ", ".join(supported_backends())
raise ValueError(
f"Invalid backend {self._backend_name}: supported options are ({backends})"
)
raise ValueError(f"Invalid backend {module_name}: supported options are ({backends})")

self.context = context
self.method = sampling_method
self.view = None
self.run = None
self.sampler = self._backend.bind(self, sampling_method, callback, **kwargs)

backend = import_module("." + self._backend_name, package="pysages.backends")
self.sampler = backend.bind(self, callback, **kwargs)

# `self.view` and `self.run` *must* be set by the backend bind function.
assert self.view is not None
assert self.run is not None

self.synchronize = self.view.synchronize

def get_backend_name(self):
@property
def backend_name(self):
return self._backend_name

def get_backend_module(self):
return self._backend

def __enter__(self):
"""
Trampoline 'with statements' to the wrapped context when the backend supports it.
Expand Down
24 changes: 11 additions & 13 deletions pysages/backends/hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax import numpy as np
from jax.dlpack import from_dlpack as asarray

from pysages.backends.core import ContextWrapper
from pysages.backends.core import SamplingContext
from pysages.backends.snapshot import (
Box,
HelperMethods,
Expand All @@ -32,7 +32,6 @@
build_data_querier,
)
from pysages.backends.snapshot import restore as _restore
from pysages.methods import SamplingMethod
from pysages.utils import check_device_array, copy

# TODO: Figure out a way to automatically tie the lifetime of Sampler
Expand Down Expand Up @@ -101,9 +100,9 @@ def update(positions, vel_mass, rtags, images, forces, timestep):

super().__init__(sysview, update, default_location(), AccessMode.Read)
self.state = initialize()
self.callback = callback
self.bias = bias
self.box = initial_snapshot.box
self.callback = callback
self.dt = initial_snapshot.dt
self._restore = restore

Expand Down Expand Up @@ -147,9 +146,9 @@ def default_location():
return AccessLocation.OnHost


def take_snapshot(wrapped_context, location=default_location()):
context = wrapped_context.context
sysview = wrapped_context.view
def take_snapshot(sampling_context, location=default_location()):
context = sampling_context.context
sysview = sampling_context.view
positions = copy(asarray(positions_types(sysview, location, AccessMode.Read)))
vel_mass = copy(asarray(velocities_masses(sysview, location, AccessMode.Read)))
forces = copy(asarray(net_forces(sysview, location, AccessMode.ReadWrite)))
Expand Down Expand Up @@ -241,17 +240,16 @@ def dimensionality():
return helpers, restore, bias


def bind(
wrapped_context: ContextWrapper, sampling_method: SamplingMethod, callback: Callable, **kwargs
):
context = wrapped_context.context
def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
context = sampling_context.context
sampling_method = sampling_context.method
sysview = SystemView(get_system(context))
wrapped_context.view = sysview
wrapped_context.run = get_run_method(context)
sampling_context.view = sysview
sampling_context.run = get_run_method(context)
helpers, restore, bias = build_helpers(context, sampling_method)

with sysview:
snapshot = take_snapshot(wrapped_context)
snapshot = take_snapshot(sampling_context)

method_bundle = sampling_method.build(snapshot, helpers)
sync_and_bias = partial(bias, sync_backend=sysview.synchronize)
Expand Down
Loading