Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SamplerState interface to OpenMM CustomCVForce #362

Merged
merged 5 commits into from
Jul 23, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions devtools/conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,20 @@ build:
requirements:
build:
- python
- cython
- numpy
- scipy
# - cython
# - numpy
# - scipy
- setuptools
- openmm
- parmed
- mdtraj
- netcdf4
- pyyaml
# - parmed
# - mdtraj
# - netcdf4
# - pyyaml

run:
- python
- cython
- numpy
- scipy
- setuptools
- six
- openmm
- parmed
Expand Down
7 changes: 7 additions & 0 deletions docs/releasehistory.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Release History
===============

0.15.1 - Something New This Way Comes [WIP]
===========================================

New features
------------
- Add ability for `SamplerState` to access new `OpenMM Custom CV Force Variables <http://docs.openmm.org/development/api-python/generated/simtk.openmm.openmm.CustomCVForce.html#simtk.openmm.openmm.CustomCVForce>`_

0.15.0 - Restraint forces
=========================
- Add radially-symmetric restraint custom forces (`#336 <https://github.com/choderalab/openmmtools/pull/336>`_).
Expand Down
105 changes: 87 additions & 18 deletions openmmtools/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import copy
import zlib
import weakref
import collections

import numpy as np
from simtk import openmm, unit
Expand Down Expand Up @@ -1618,6 +1619,7 @@ class SamplerState(object):
total_energy
volume
n_particles
collective_variables

Examples
--------
Expand Down Expand Up @@ -1676,6 +1678,13 @@ class SamplerState(object):
# -------------------------------------------------------------------------

def __init__(self, positions, velocities=None, box_vectors=None):
# Create a sampler state from positions and, optionally, velocities and
# box vectors. Every private attribute is declared here as None first so
# that it exists on the instance before _initialize() assigns real values.
# Allocate variables, they get set in _initialize
self._positions = None
self._velocities = None
self._box_vectors = None
self._collective_variables = None
self._kinetic_energy = None
self._potential_energy = None
# Deep copies of the arguments are handed to _initialize(). Energies and
# collective variables are not passed, so _initialize() falls back to its
# None defaults for them.
self._initialize(copy.deepcopy(positions), copy.deepcopy(velocities), copy.deepcopy(box_vectors))

@classmethod
Expand Down Expand Up @@ -1760,16 +1769,18 @@ def box_vectors(self, value):
value = unit.Quantity(value)
self._box_vectors = value

# Derived properties

@property
def potential_energy(self):
"""simtk.unit.Quantity or None: Potential energy of this configuration."""
if self.positions is None or self.positions.has_changed:
if self._test_positions_valid:
return None
return self._potential_energy

@potential_energy.setter
def potential_energy(self, new_value):
self._potential_energy = new_value
raise AttributeError("Cannot set potential energy as it is a function of Context")

@property
def kinetic_energy(self):
Expand All @@ -1780,7 +1791,18 @@ def kinetic_energy(self):

@kinetic_energy.setter
def kinetic_energy(self, new_value):
self._kinetic_energy = new_value
raise AttributeError("Cannot set kinetic energy as it is a function of Context")

@property
def collective_variables(self):
"""dict or None: Collective variables for this configuration if present in Context"""
# NOTE: despite its name, _test_positions_valid is True when the positions
# are None or are flagged as changed, i.e. when the cached collective
# variables may no longer describe the current positions.
if self._test_positions_valid:
return None
return self._collective_variables

@collective_variables.setter
def collective_variables(self, new_value):
# The setter always raises: assigning to this property is rejected with an
# explicit error message, whatever new_value is.
raise AttributeError("Cannot set collective variables as it is a function of Context")

@property
def total_energy(self):
Expand Down Expand Up @@ -1833,7 +1855,9 @@ def update_from_context(self, context_state):
----------
context_state : simtk.openmm.Context or simtk.openmm.State
The object to read. If a State, it must contain information
on positions, velocities and energies.
on positions, velocities and energies. Collective
variables can only be updated from a Context, NOT a State
at the moment. If a State is provided, collective variables are nullified.

Raises
------
Expand All @@ -1843,7 +1867,7 @@ def update_from_context(self, context_state):
"""
self._read_context_state(context_state, check_consistency=True)

def apply_to_context(self, context, ignore_velocities=False):
def apply_to_context(self, context, ignore_velocities=False, ignore_collective_variables=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was remembering incorrectly and I thought there was an ignore_velocities in update_from_context() on top of apply_to_context. I wouldn't have an ignore_collective_variables argument here anyway. This function's task is to write to the context, so it shouldn't also read information from it, and collective variables, like the energies, are not written.

I'm now thinking we should also add optional ignore_positions, ignore_velocities, ignore_energies, and ignore_collective_variables arguments to update_from_context(), similarly to OpenMM's Context.getState(), to allow users to optimize reading off the GPU. We'll definitely need ignore_collective_variables in any case.

"""Set the context state.

If velocities and box vectors have not been specified in the
Expand All @@ -1857,6 +1881,10 @@ def apply_to_context(self, context, ignore_velocities=False):
If True, velocities are not set in the Context even if they
are defined. This can be useful if you only need to use the
Context only to compute energies.
ignore_collective_variables : bool, optional
If True, the collective variables are not updated from the
Context, so the SamplerState's internal tracking
of their values remains unchanged

"""
# NOTE: Box vectors MUST be updated before positions are set.
Expand All @@ -1865,6 +1893,8 @@ def apply_to_context(self, context, ignore_velocities=False):
context.setPositions(self._unitless_positions)
if self._velocities is not None and not ignore_velocities:
context.setVelocities(self._unitless_velocities)
if not ignore_collective_variables:
self._set_collective_variables(context)

def has_nan(self):
"""Check that energies and positions are finite.
Expand Down Expand Up @@ -1906,17 +1936,19 @@ def __getitem__(self, item):
# Copy box vectors.
sampler_state.box_vectors = copy.deepcopy(self.box_vectors)

# Energies for only a subset of atoms is undefined.
sampler_state.potential_energy = None
sampler_state.kinetic_energy = None
# Energies/CV's for only a subset of atoms is undefined.
sampler_state._potential_energy = None
sampler_state._kinetic_energy = None
sampler_state._collective_variables = None
return sampler_state

def __getstate__(self):
"""Return a dictionary representation of the state."""
serialization = dict(
positions=self.positions, velocities=self.velocities,
box_vectors=self.box_vectors, potential_energy=self.potential_energy,
kinetic_energy=self.kinetic_energy
kinetic_energy=self.kinetic_energy,
collective_variables=self.collective_variables
)
return serialization

Expand All @@ -1929,13 +1961,14 @@ def __setstate__(self, serialization):
# -------------------------------------------------------------------------

def _initialize(self, positions, velocities, box_vectors,
potential_energy=None, kinetic_energy=None):
potential_energy=None, kinetic_energy=None, collective_variables=None):
"""Initialize the sampler state."""
self._set_positions(positions, from_context=False, check_consistency=False)
self.velocities = velocities # Checks consistency and units.
self.box_vectors = box_vectors # Make sure box vectors is Quantity.
self.potential_energy = potential_energy
self.kinetic_energy = kinetic_energy
self._potential_energy = potential_energy
self._kinetic_energy = kinetic_energy
self._collective_variables = collective_variables

def _set_positions(self, new_positions, from_context, check_consistency):
"""Set the positions without checking for consistency."""
Expand All @@ -1951,7 +1984,10 @@ def _set_positions(self, new_positions, from_context, check_consistency):
self._positions = utils.TrackedQuantity(new_positions)

# The potential energy changes with different positions.
self.potential_energy = None
self._potential_energy = None

# The CVs change with different positions too
self._collective_variables = None

def _set_velocities(self, new_velocities, from_context):
"""Set the velocities."""
Expand All @@ -1968,7 +2004,7 @@ def _set_velocities(self, new_velocities, from_context):
self._velocities = new_velocities

# The kinetic energy changes with different velocities.
self.kinetic_energy = None
self._kinetic_energy = None

@property
def _unitless_positions(self):
Expand All @@ -1977,7 +2013,7 @@ def _unitless_positions(self):
self._unitless_positions_cache = self.positions.value_in_unit_system(unit.md_unit_system)
if self._positions.has_changed:
self._positions.has_changed = False
self.potential_energy = None
self._potential_energy = None
return self._unitless_positions_cache

@property
Expand All @@ -1989,7 +2025,7 @@ def _unitless_velocities(self):
self._unitless_velocities_cache = self._velocities.value_in_unit_system(unit.md_unit_system)
if self._velocities.has_changed:
self._velocities.has_changed = False
self.kinetic_energy = None
self._kinetic_energy = None
return self._unitless_velocities_cache

def _read_context_state(self, context_state, check_consistency):
Expand Down Expand Up @@ -2027,8 +2063,41 @@ def _read_context_state(self, context_state, check_consistency):
self.box_vectors = openmm_state.getPeriodicBoxVectors(asNumpy=True)
# Potential energy and kinetic energy must be updated
# after positions and velocities or they'll be reset.
self.potential_energy = openmm_state.getPotentialEnergy()
self.kinetic_energy = openmm_state.getKineticEnergy()
self._potential_energy = openmm_state.getPotentialEnergy()
self._kinetic_energy = openmm_state.getKineticEnergy()
self._set_collective_variables(context_state)

def _set_collective_variables(self, context_state):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you call this maybe _read_collective_variables to be consistent with _read_context_state?

"""
Update the collective variables from the context object

Parameters
----------
context_state : simtk.openmm.Context or simtk.openmm.State
The object to read. This only works with Context's for now,
but in the future, this may support OpenMM State objects as well.
"""
if isinstance(context_state, openmm.State):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should raise an exception if the user passes an openmm.State object and get_collective_variables=True. Otherwise, this will fail silently and end up with a value for SamplerState.collective_variables indistinguishable from having a System with no collective variables.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like this idea. It could force the user to pass get_collective_variables=False in every call to this function, even though collective variables are the less common case.

SamplerState.collective_variables indistinguishable from having a System with no collective variables.

Yes, though I thought that was partially the point? If a user never supplies the information about a given state, then we can't use that. I don't like the idea of users now being required to track collective variables, which are an optional construct in and of themselves.

Copy link
Contributor

@andrrizzi andrrizzi Jul 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user wouldn't have to set get_collective_variables=False in every call as long as they pass an actual Context object. Passing the State object is there exclusively for optimization (i.e. if I already have one around I don't have to read the info off the Context again).

In any case, I think avoiding silent failing would be worth forcing the users to add the extra option, and the default behavior should be the safest. With this implementation, for example, all our metropolized MCMCMoves are currently failing silently and not retrieving CV variables that should actually be available after apply has been performed.

self._collective_variables = None
return
# Allows direct key assignment without initializing each key:dict pair
collective_variables = collections.defaultdict(dict)
system = context_state.getSystem()
for force_index, force in enumerate(system.getForces()):
try:
cv_values = force.getCollectiveVariableValues(context_state)
for cv_index in range(force.getNumCollectiveVariables()):
cv_name = force.getCollectiveVariableName(cv_index)
collective_variables[cv_name][force_index] = cv_values[cv_index]
except AttributeError:
pass
# Trap no variables found (empty dict), return None
self._collective_variables = collective_variables if collective_variables else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would convert the defaultdict to a normal dict either here or in the SamplerState.collective_variables property, since the user expects that type.


@property
def _test_positions_valid(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename this as _are_positions_valid, which I believe is the most common name convention for these.

"""Helper function to reduce this check duplication in multiple properties"""
return self.positions is None or self.positions.has_changed


# =============================================================================
Expand Down
70 changes: 57 additions & 13 deletions openmmtools/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,12 +854,11 @@ def test_unitless_cache(self):
sampler_state.positions = copy.deepcopy(positions)
assert sampler_state._unitless_positions_cache is None

# TODO reactivate this test once OpenMM 7.2 is released with bugfix for #1940
# if isinstance(sampler_state._positions._value, np.ndarray):
# old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions)
# sampler_state.positions[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * pos_unit
# assert sampler_state.positions.has_changed
# assert np.all(old_unitless_positions[5:8] != sampler_state._unitless_positions[5:8])
if isinstance(sampler_state._positions._value, np.ndarray):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for re-activating these tests! I forgot about this.

old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions)
sampler_state.positions[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * pos_unit
assert sampler_state.positions.has_changed
assert np.all(old_unitless_positions[5:8] != sampler_state._unitless_positions[5:8])

if sampler_state.velocities is not None:
old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities)
Expand All @@ -869,16 +868,14 @@ def test_unitless_cache(self):
sampler_state.velocities = copy.deepcopy(sampler_state.velocities)
assert sampler_state._unitless_velocities_cache is None

# TODO reactivate this test once OpenMM 7.2 is released with bugfix for #1940
# if isinstance(sampler_state._velocities._value, np.ndarray):
# old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities)
# sampler_state.velocities[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * vel_unit
# assert sampler_state.velocities.has_changed
# assert np.all(old_unitless_velocities[5:8] != sampler_state._unitless_velocities[5:8])
if isinstance(sampler_state._velocities._value, np.ndarray):
old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities)
sampler_state.velocities[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * vel_unit
assert sampler_state.velocities.has_changed
assert np.all(old_unitless_velocities[5:8] != sampler_state._unitless_velocities[5:8])
else:
assert sampler_state._unitless_velocities is None


def test_method_is_context_compatible(self):
"""SamplerState.is_context_compatible() method."""
# Vacuum.
Expand Down Expand Up @@ -966,6 +963,53 @@ def test_operator_getitem(self):
assert sliced_sampler_state.kinetic_energy is None
assert sliced_sampler_state.potential_energy is None

def test_collective_variable(self):
"""Test that CV calculation is working (If on OpenMM >=7.3)"""
# TODO: Remove the if statement and require OpenMM 7.3 once 7.3 is actually released
# Without CustomCVForce the test returns early, so it passes without checking anything.
if not hasattr(openmm, "CustomCVForce"):
return
# Setup the CV tests if we have a late enough OpenMM
# alanine_explicit_cv = copy.deepcopy(self.alanine_explicit)
system_cv = self.alanine_explicit_state.system
# Two elementary CVs: the distance between particles 0 and 1, and the angle over particles 0, 1, 2.
cv_distance = openmm.CustomBondForce("r")
cv_distance.addBond(0, 1, [])
cv_angle = openmm.CustomAngleForce("theta")
cv_angle.addAngle(0, 1, 2, [])
# 3 unique CV names in the Context: BondCV, AngleCVSingle, AngleCV
cv_single_1 = openmm.CustomCVForce("4*BondCV")
# We are going to use this name later too
cv_single_1.addCollectiveVariable('BondCV', copy.deepcopy(cv_distance))
cv_single_2 = openmm.CustomCVForce("sin(AngleCVSingle)") # This is supposed to be unique
cv_single_2.addCollectiveVariable('AngleCVSingle', copy.deepcopy(cv_angle))
# The combined force reuses the name BondCV, so that name is registered in two different forces.
cv_combined = openmm.CustomCVForce("4*BondCV + sin(AngleCV)")
cv_combined.addCollectiveVariable("BondCV", cv_distance)
cv_combined.addCollectiveVariable("AngleCV", cv_angle)
for force in [cv_single_1, cv_single_2, cv_combined]:
system_cv.addForce(force)
thermo_state = ThermodynamicState(system_cv, self.alanine_explicit_state.temperature)
context = self.create_context(thermo_state)
context.setPositions(self.alanine_explicit_positions)
sampler_state = SamplerState.from_context(context)
collective_variables = sampler_state.collective_variables
# (CV name, number of forces expected to report a value under that name)
name_count = (('BondCV', 2), ('AngleCV', 1), ('AngleCVSingle', 1))
# Ensure the CV's are all accounted for
assert len(collective_variables.keys()) == 3
for name, count in name_count:
# Ensure the CV's show up in the Context the number of times we expect them to
assert len(collective_variables[name].keys()) == count
# Ensure CVs which are the same in different forces are equal
assert len(set(collective_variables['BondCV'].values())) == 1 # Cast values of CV to set, make sure len == 1
# Ensure invalidation with single replacement
new_pos = copy.deepcopy(self.alanine_explicit_positions)
new_pos[0] *= 2
sampler_state.positions[0] = new_pos[0]
assert sampler_state.collective_variables is None
# Ensure CV's are read from context
sampler_state.update_from_context(context)
assert sampler_state.collective_variables is not None
# Ensure invalidation with full variable swap
sampler_state.positions = new_pos
assert sampler_state.collective_variables is None

# =============================================================================
# TEST COMPOUND STATE
Expand Down
8 changes: 1 addition & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
Various Python tools for OpenMM.
"""
from __future__ import print_function
import os
import sys
from distutils.core import setup, Extension
from setuptools import setup, Extension, find_packages
import numpy
import glob
from setuptools import setup
import os
from os.path import relpath, join
import subprocess
Expand Down Expand Up @@ -171,8 +167,6 @@ def check_dependencies():
packages=['openmmtools', 'openmmtools.tests', 'openmmtools.scripts', 'openmmtools.storage'],
package_dir={'openmmtools': 'openmmtools'},
package_data={'openmmtools': find_package_data('openmmtools/data', 'openmmtools')},
install_requires=['numpy', 'scipy', 'openmm', 'parmed', 'mdtraj', 'netCDF4', 'pyyaml'],
tests_requires=['nose', 'pymbar', 'netCDF4', 'pyyaml'],
zip_safe=False,
scripts=[],
ext_modules=extensions,
Expand Down