-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SamplerState interface to OpenMM CustomCVForce #362
Changes from 2 commits
ccc4a2e
46ec139
3ed678d
574ba24
8e3d46b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
import copy | ||
import zlib | ||
import weakref | ||
import collections | ||
|
||
import numpy as np | ||
from simtk import openmm, unit | ||
|
@@ -1618,6 +1619,7 @@ class SamplerState(object): | |
total_energy | ||
volume | ||
n_particles | ||
collective_variables | ||
|
||
Examples | ||
-------- | ||
|
@@ -1676,6 +1678,13 @@ class SamplerState(object): | |
# ------------------------------------------------------------------------- | ||
|
||
def __init__(self, positions, velocities=None, box_vectors=None): | ||
# Allocate variables, they get set in _initialize | ||
self._positions = None | ||
self._velocities = None | ||
self._box_vectors = None | ||
self._collective_variables = None | ||
self._kinetic_energy = None | ||
self._potential_energy = None | ||
self._initialize(copy.deepcopy(positions), copy.deepcopy(velocities), copy.deepcopy(box_vectors)) | ||
|
||
@classmethod | ||
|
@@ -1760,16 +1769,18 @@ def box_vectors(self, value): | |
value = unit.Quantity(value) | ||
self._box_vectors = value | ||
|
||
# Derived properties | ||
|
||
@property | ||
def potential_energy(self): | ||
"""simtk.unit.Quantity or None: Potential energy of this configuration.""" | ||
if self.positions is None or self.positions.has_changed: | ||
if self._test_positions_valid: | ||
return None | ||
return self._potential_energy | ||
|
||
@potential_energy.setter | ||
def potential_energy(self, new_value): | ||
self._potential_energy = new_value | ||
raise AttributeError("Cannot set potential energy as it is a function of Context") | ||
|
||
@property | ||
def kinetic_energy(self): | ||
|
@@ -1780,7 +1791,18 @@ def kinetic_energy(self): | |
|
||
@kinetic_energy.setter | ||
def kinetic_energy(self, new_value): | ||
self._kinetic_energy = new_value | ||
raise AttributeError("Cannot set kinetic energy as it is a function of Context") | ||
|
||
@property | ||
def collective_variables(self): | ||
"""dict or None: Collective variables for this configuration if present in Context""" | ||
if self._test_positions_valid: | ||
return None | ||
return self._collective_variables | ||
|
||
@collective_variables.setter | ||
def collective_variables(self, new_value): | ||
raise AttributeError("Cannot set collective variables as it is a function of Context") | ||
|
||
@property | ||
def total_energy(self): | ||
|
@@ -1833,7 +1855,9 @@ def update_from_context(self, context_state): | |
---------- | ||
context_state : simtk.openmm.Context or simtk.openmm.State | ||
The object to read. If a State, it must contain information | ||
on positions, velocities and energies. | ||
on positions, velocities and energies. Collective | ||
variables can only be updated from a Context, NOT a State | ||
at the moment. If a State is provided, collective variables are nullified. | ||
|
||
Raises | ||
------ | ||
|
@@ -1843,7 +1867,7 @@ def update_from_context(self, context_state): | |
""" | ||
self._read_context_state(context_state, check_consistency=True) | ||
|
||
def apply_to_context(self, context, ignore_velocities=False): | ||
def apply_to_context(self, context, ignore_velocities=False, ignore_collective_variables=False): | ||
"""Set the context state. | ||
|
||
If velocities and box vectors have not been specified in the | ||
|
@@ -1857,6 +1881,10 @@ def apply_to_context(self, context, ignore_velocities=False): | |
If True, velocities are not set in the Context even if they | ||
are defined. This can be useful if you only need to use the | ||
Context only to compute energies. | ||
ignore_collective_variables : bool, optional | ||
If True, the collective variables are not updated from the | ||
Context, so the SamplerState's internal tracking | ||
of their values remains unchanged | ||
|
||
""" | ||
# NOTE: Box vectors MUST be updated before positions are set. | ||
|
@@ -1865,6 +1893,8 @@ def apply_to_context(self, context, ignore_velocities=False): | |
context.setPositions(self._unitless_positions) | ||
if self._velocities is not None and not ignore_velocities: | ||
context.setVelocities(self._unitless_velocities) | ||
if not ignore_collective_variables: | ||
self._set_collective_variables(context) | ||
|
||
def has_nan(self): | ||
"""Check that energies and positions are finite. | ||
|
@@ -1906,17 +1936,19 @@ def __getitem__(self, item): | |
# Copy box vectors. | ||
sampler_state.box_vectors = copy.deepcopy(self.box_vectors) | ||
|
||
# Energies for only a subset of atoms is undefined. | ||
sampler_state.potential_energy = None | ||
sampler_state.kinetic_energy = None | ||
# Energies/CV's for only a subset of atoms is undefined. | ||
sampler_state._potential_energy = None | ||
sampler_state._kinetic_energy = None | ||
sampler_state._collective_variables = None | ||
return sampler_state | ||
|
||
def __getstate__(self): | ||
"""Return a dictionary representation of the state.""" | ||
serialization = dict( | ||
positions=self.positions, velocities=self.velocities, | ||
box_vectors=self.box_vectors, potential_energy=self.potential_energy, | ||
kinetic_energy=self.kinetic_energy | ||
kinetic_energy=self.kinetic_energy, | ||
collective_variables=self.collective_variables | ||
) | ||
return serialization | ||
|
||
|
@@ -1929,13 +1961,14 @@ def __setstate__(self, serialization): | |
# ------------------------------------------------------------------------- | ||
|
||
def _initialize(self, positions, velocities, box_vectors, | ||
potential_energy=None, kinetic_energy=None): | ||
potential_energy=None, kinetic_energy=None, collective_variables=None): | ||
"""Initialize the sampler state.""" | ||
self._set_positions(positions, from_context=False, check_consistency=False) | ||
self.velocities = velocities # Checks consistency and units. | ||
self.box_vectors = box_vectors # Make sure box vectors is Quantity. | ||
self.potential_energy = potential_energy | ||
self.kinetic_energy = kinetic_energy | ||
self._potential_energy = potential_energy | ||
self._kinetic_energy = kinetic_energy | ||
self._collective_variables = collective_variables | ||
|
||
def _set_positions(self, new_positions, from_context, check_consistency): | ||
"""Set the positions without checking for consistency.""" | ||
|
@@ -1951,7 +1984,10 @@ def _set_positions(self, new_positions, from_context, check_consistency): | |
self._positions = utils.TrackedQuantity(new_positions) | ||
|
||
# The potential energy changes with different positions. | ||
self.potential_energy = None | ||
self._potential_energy = None | ||
|
||
# The CVs change with different positions too | ||
self._collective_variables = None | ||
|
||
def _set_velocities(self, new_velocities, from_context): | ||
"""Set the velocities.""" | ||
|
@@ -1968,7 +2004,7 @@ def _set_velocities(self, new_velocities, from_context): | |
self._velocities = new_velocities | ||
|
||
# The kinetic energy changes with different positions. | ||
self.kinetic_energy = None | ||
self._kinetic_energy = None | ||
|
||
@property | ||
def _unitless_positions(self): | ||
|
@@ -1977,7 +2013,7 @@ def _unitless_positions(self): | |
self._unitless_positions_cache = self.positions.value_in_unit_system(unit.md_unit_system) | ||
if self._positions.has_changed: | ||
self._positions.has_changed = False | ||
self.potential_energy = None | ||
self._potential_energy = None | ||
return self._unitless_positions_cache | ||
|
||
@property | ||
|
@@ -1989,7 +2025,7 @@ def _unitless_velocities(self): | |
self._unitless_velocities_cache = self._velocities.value_in_unit_system(unit.md_unit_system) | ||
if self._velocities.has_changed: | ||
self._velocities.has_changed = False | ||
self.kinetic_energy = None | ||
self._kinetic_energy = None | ||
return self._unitless_velocities_cache | ||
|
||
def _read_context_state(self, context_state, check_consistency): | ||
|
@@ -2027,8 +2063,41 @@ def _read_context_state(self, context_state, check_consistency): | |
self.box_vectors = openmm_state.getPeriodicBoxVectors(asNumpy=True) | ||
# Potential energy and kinetic energy must be updated | ||
# after positions and velocities or they'll be reset. | ||
self.potential_energy = openmm_state.getPotentialEnergy() | ||
self.kinetic_energy = openmm_state.getKineticEnergy() | ||
self._potential_energy = openmm_state.getPotentialEnergy() | ||
self._kinetic_energy = openmm_state.getKineticEnergy() | ||
self._set_collective_variables(context_state) | ||
|
||
def _set_collective_variables(self, context_state): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you call this maybe |
||
""" | ||
Update the collective variables from the context object | ||
|
||
Parameters | ||
---------- | ||
context_state : simtk.openmm.Context or simtk.openmm.State | ||
The object to read. This only works with Context's for now, | ||
but in the future, this may support OpenMM State objects as well. | ||
""" | ||
if isinstance(context_state, openmm.State): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should raise an exception if the user passes an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really like this idea. This can lead the user to be mandated to use the
Yes, though I thought that was partially the point? If a user never supplies the information about a given state, then we can't use that. I don't like the idea of the users being required to now track collective variables, which are an optional construct in of themselves. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The user wouldn't have to set In any case, I think avoiding silent failing would be worth forcing the users to add the extra option, and the default behavior should be the safest. With this implementation, for example, all our metropolized |
||
self._collective_variables = None | ||
return | ||
# Allows direct key assignment without initializing each key:dict pair | ||
collective_variables = collections.defaultdict(dict) | ||
system = context_state.getSystem() | ||
for force_index, force in enumerate(system.getForces()): | ||
try: | ||
cv_values = force.getCollectiveVariableValues(context_state) | ||
for cv_index in range(force.getNumCollectiveVariables()): | ||
cv_name = force.getCollectiveVariableName(cv_index) | ||
collective_variables[cv_name][force_index] = cv_values[cv_index] | ||
except AttributeError: | ||
pass | ||
# Trap no variables found (empty dict), return None | ||
self._collective_variables = collective_variables if collective_variables else None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would convert the |
||
|
||
@property | ||
def _test_positions_valid(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe rename this as |
||
"""Helper function to reduce this check duplication in multiple properties""" | ||
return self.positions is None or self.positions.has_changed | ||
|
||
|
||
# ============================================================================= | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -854,12 +854,11 @@ def test_unitless_cache(self): | |
sampler_state.positions = copy.deepcopy(positions) | ||
assert sampler_state._unitless_positions_cache is None | ||
|
||
# TODO reactivate this test once OpenMM 7.2 is released with bugfix for #1940 | ||
# if isinstance(sampler_state._positions._value, np.ndarray): | ||
# old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions) | ||
# sampler_state.positions[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * pos_unit | ||
# assert sampler_state.positions.has_changed | ||
# assert np.all(old_unitless_positions[5:8] != sampler_state._unitless_positions[5:8]) | ||
if isinstance(sampler_state._positions._value, np.ndarray): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for re-activating these tests! I forgot about this. |
||
old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions) | ||
sampler_state.positions[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * pos_unit | ||
assert sampler_state.positions.has_changed | ||
assert np.all(old_unitless_positions[5:8] != sampler_state._unitless_positions[5:8]) | ||
|
||
if sampler_state.velocities is not None: | ||
old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities) | ||
|
@@ -869,16 +868,14 @@ def test_unitless_cache(self): | |
sampler_state.velocities = copy.deepcopy(sampler_state.velocities) | ||
assert sampler_state._unitless_velocities_cache is None | ||
|
||
# TODO reactivate this test once OpenMM 7.2 is released with bugfix for #1940 | ||
# if isinstance(sampler_state._velocities._value, np.ndarray): | ||
# old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities) | ||
# sampler_state.velocities[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * vel_unit | ||
# assert sampler_state.velocities.has_changed | ||
# assert np.all(old_unitless_velocities[5:8] != sampler_state._unitless_velocities[5:8]) | ||
if isinstance(sampler_state._velocities._value, np.ndarray): | ||
old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities) | ||
sampler_state.velocities[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * vel_unit | ||
assert sampler_state.velocities.has_changed | ||
assert np.all(old_unitless_velocities[5:8] != sampler_state._unitless_velocities[5:8]) | ||
else: | ||
assert sampler_state._unitless_velocities is None | ||
|
||
|
||
def test_method_is_context_compatible(self): | ||
"""SamplerState.is_context_compatible() method.""" | ||
# Vacuum. | ||
|
@@ -966,6 +963,53 @@ def test_operator_getitem(self): | |
assert sliced_sampler_state.kinetic_energy is None | ||
assert sliced_sampler_state.potential_energy is None | ||
|
||
def test_collective_variable(self): | ||
"""Test that CV calculation is working (If on OpenMM >=7.3)""" | ||
# TODO: Remove the if statement and require OpenMM 7.3 once 7.3 is actually released | ||
if not hasattr(openmm, "CustomCVForce"): | ||
return | ||
# Setup the CV tests if we have a late enough OpenMM | ||
# alanine_explicit_cv = copy.deepcopy(self.alanine_explicit) | ||
system_cv = self.alanine_explicit_state.system | ||
cv_distance = openmm.CustomBondForce("r") | ||
cv_distance.addBond(0, 1, []) | ||
cv_angle = openmm.CustomAngleForce("theta") | ||
cv_angle.addAngle(0, 1, 2, []) | ||
# 3 unique CV names in the Context: BondCV, AngleCVSingle, AngleCV | ||
cv_single_1 = openmm.CustomCVForce("4*BondCV") | ||
# We are going to use this name later too | ||
cv_single_1.addCollectiveVariable('BondCV', copy.deepcopy(cv_distance)) | ||
cv_single_2 = openmm.CustomCVForce("sin(AngleCVSingle)") # This is suppose to be unique | ||
cv_single_2.addCollectiveVariable('AngleCVSingle', copy.deepcopy(cv_angle)) | ||
cv_combined = openmm.CustomCVForce("4*BondCV + sin(AngleCV)") | ||
cv_combined.addCollectiveVariable("BondCV", cv_distance) | ||
cv_combined.addCollectiveVariable("AngleCV", cv_angle) | ||
for force in [cv_single_1, cv_single_2, cv_combined]: | ||
system_cv.addForce(force) | ||
thermo_state = ThermodynamicState(system_cv, self.alanine_explicit_state.temperature) | ||
context = self.create_context(thermo_state) | ||
context.setPositions(self.alanine_explicit_positions) | ||
sampler_state = SamplerState.from_context(context) | ||
collective_variables = sampler_state.collective_variables | ||
name_count = (('BondCV', 2), ('AngleCV', 1), ('AngleCVSingle', 1)) | ||
# Ensure the CV's are all accounted for | ||
assert len(collective_variables.keys()) == 3 | ||
for name, count in name_count: | ||
# Ensure the CV's show up in the Context the number of times we expect them to | ||
assert len(collective_variables[name].keys()) == count | ||
# Ensure CVs which are the same in different forces are equal | ||
assert len(set(collective_variables['BondCV'].values())) == 1 # Cast values of CV to set, make sure len == 1 | ||
# Ensure invalidation with single replacement | ||
new_pos = copy.deepcopy(self.alanine_explicit_positions) | ||
new_pos[0] *= 2 | ||
sampler_state.positions[0] = new_pos[0] | ||
assert sampler_state.collective_variables is None | ||
# Ensure CV's are read from context | ||
sampler_state.update_from_context(context) | ||
assert sampler_state.collective_variables is not None | ||
# Ensure invalidation with full variable swap | ||
sampler_state.positions = new_pos | ||
assert sampler_state.collective_variables is None | ||
|
||
# ============================================================================= | ||
# TEST COMPOUND STATE | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was remembering incorrectly and I thought there was an
ignore_velocities
inupdate_from_context()
on top ofapply_to_context
. I wouldn't have anignore_collective_variables
argument here anyway. This function task is to write the context so it shouldn't also read information from it, and collective variables like the energies, are not written.I'm now thinking we should also add optional
ignore_positions
,ignore_velocities
,ignore_energies
, andignore_collective_variables
argument inupdate_from_context()
, similarly to OpenMM'sContext.getState()
to allow users to optimize reading off the GPU. We'll definitely need theignore_collective_variables
in any case.