Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exact PME electrostatics treatment and fast computation of reduced potentials #320

Merged
merged 33 commits into from
Jan 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
9464924
Encapsulate parts of _alchemically_modify_NonbondedForce for readability
andrrizzi Jan 13, 2018
8514e0e
Implement exact treatment of PME electrostatics in absolute alchemica…
andrrizzi Jan 13, 2018
9447ba4
Add AlchemicalState support for exact PME treatment
andrrizzi Jan 15, 2018
af77671
Add tests exact PME electrostatics treatment
andrrizzi Jan 15, 2018
3fa4a73
Fix python 2 nested scope problem in tests
andrrizzi Jan 15, 2018
00d2df1
Try to fix the weird Travis error
andrrizzi Jan 15, 2018
4a0ea46
Add global variable lambda_electrostatics to original charges force f…
andrrizzi Jan 15, 2018
9cc6a05
Update charges in context only if lambda_electrostatics has changed
andrrizzi Jan 15, 2018
237dc87
Add option to split alchemical forces into force groups
andrrizzi Jan 16, 2018
a046875
Implement fast computation of reduced potential for a list of states.
andrrizzi Jan 16, 2018
908d2ca
Add AlchemicalState support for fast computation of reduced potentials
andrrizzi Jan 16, 2018
3214894
group_by_compatibility return also the original indices
andrrizzi Jan 16, 2018
429db1e
Fix nose tests
andrrizzi Jan 16, 2018
c9ccee9
Optimization of ThermodynamicState apply_to_context
andrrizzi Jan 17, 2018
c78366d
Add memoization to AlchemicalState._find_force_groups_to_update
andrrizzi Jan 17, 2018
3016f5e
Raise error if there are not enough force groups.
andrrizzi Jan 17, 2018
77816a3
Implemented utility class TrackedQuantity
andrrizzi Jan 19, 2018
a0b54ca
Add _on_setattr callback to ComposableStates
andrrizzi Jan 19, 2018
890e980
SamplerState caches unitless positions and velocities
andrrizzi Jan 19, 2018
b80dca6
Allow incompatible AlchemicalStates with exact PME
andrrizzi Jan 19, 2018
cf6afad
Bump dev version to 0.14.0
andrrizzi Jan 19, 2018
d18100f
Fix bug in force searching
andrrizzi Jan 19, 2018
8386c0d
Silence test until openmm 7.2 is released
andrrizzi Jan 19, 2018
0fcdddd
Whops! Re-fixing the bug
andrrizzi Jan 19, 2018
10a679d
Fix update standard system weakref cache after attribute setting
andrrizzi Jan 19, 2018
0415630
Fix changes in compatibility after attribute setting
andrrizzi Jan 19, 2018
e7b1610
Fix changes to the shared standard system
andrrizzi Jan 19, 2018
7ee1ed5
Fix callback
andrrizzi Jan 19, 2018
e7ac551
Merge branch 'master' into exact-pme
andrrizzi Jan 19, 2018
6515dde
Fix callback in direct-space PME case.
andrrizzi Jan 19, 2018
479e34a
Update releasehistory.rst
andrrizzi Jan 20, 2018
68298c2
Copy thermodynamic state on compound state initialization
andrrizzi Jan 23, 2018
15bf685
Backward-compatible AlchemicalState serialization
andrrizzi Jan 25, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/releasehistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Development snapshot
New features
------------
- Add a ``WaterCluster`` testsystem (`#322 <https://github.com/choderalab/openmmtools/pull/322>`_)
- Add exact treatment of PME electrostatics in `alchemy.AbsoluteAlchemicalFactory`. (`#320 <https://github.com/choderalab/openmmtools/pull/320>`_)
- Add method in ``ThermodynamicState`` for the efficient computation of the reduced potential at a list of states. (`#320 <https://github.com/choderalab/openmmtools/pull/320>`_)
- When a ``SamplerState`` is applied to many ``Context``s, the units are stripped only once for optimization. (`#320 <https://github.com/choderalab/openmmtools/pull/320>`_)


0.13.4 - Barostat/External Force Bugfix, Restart Robustness
Expand Down Expand Up @@ -248,4 +251,4 @@ Updates to openmmtools.alchemy.AlchemicalFactory
New ``openmmtools.testsystems`` classes
---------------------------------------

- AlchemicalWaterBox was added, which has the first water molecule in the system alchemically modified
- AlchemicalWaterBox was added, which has the first water molecule in the system alchemically modified
864 changes: 655 additions & 209 deletions openmmtools/alchemy.py

Large diffs are not rendered by default.

617 changes: 464 additions & 153 deletions openmmtools/states.py

Large diffs are not rendered by default.

662 changes: 510 additions & 152 deletions openmmtools/tests/test_alchemy.py

Large diffs are not rendered by default.

150 changes: 133 additions & 17 deletions openmmtools/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def setup_class(cls):
cls.alanine_positions = alanine_explicit.positions
cls.alanine_no_thermostat = alanine_explicit.system

cls.toluene_implicit = testsystems.TolueneImplicit().system
toluene_implicit = testsystems.TolueneImplicit()
cls.toluene_positions = toluene_implicit.positions
cls.toluene_implicit = toluene_implicit.system
cls.toluene_vacuum = testsystems.TolueneVacuum().system
thermostat = openmm.AndersenThermostat(cls.std_temperature,
1.0/unit.picosecond)
Expand Down Expand Up @@ -187,15 +189,6 @@ def test_method_is_barostat_consistent(self):
barostat = openmm.MonteCarloBarostat(pressure, temperature + 10*unit.kelvin)
assert not state._is_barostat_consistent(barostat)

def test_method_set_barostat_temperature(self):
"""ThermodynamicState._set_barostat_temperature() method."""
barostat = openmm.MonteCarloBarostat(self.std_pressure, self.std_temperature)
new_temperature = self.std_temperature + 10*unit.kelvin

assert ThermodynamicState._set_barostat_temperature(barostat, new_temperature)
assert get_barostat_temperature(barostat) == new_temperature
assert not ThermodynamicState._set_barostat_temperature(barostat, new_temperature)

def test_method_set_system_temperature(self):
"""ThermodynamicState._set_system_temperature() method."""
system = copy.deepcopy(self.alanine_no_thermostat)
Expand Down Expand Up @@ -260,7 +253,6 @@ def test_property_pressure_barostat(self):

# Correctly reads and set system pressures
periodic_testcases = [self.alanine_explicit]
print('ON IT!')
for system in periodic_testcases:
state = ThermodynamicState(system, self.std_temperature)
assert state.pressure is None
Expand Down Expand Up @@ -506,15 +498,15 @@ def test_method_set_integrator_temperature(self):

for thermostated, integrator in test_cases:
if thermostated:
state._set_integrator_temperature(integrator)
assert state._set_integrator_temperature(integrator)
for _integrator in ThermodynamicState._loop_over_integrators(integrator):
try:
assert _integrator.getTemperature() == new_temperature
except AttributeError: # handle CompoundIntegrator case
pass
else:
# It doesn't explode with integrators not coupled to a heat bath
state._set_integrator_temperature(integrator)
assert not state._set_integrator_temperature(integrator)

def test_method_standardize_system(self):
"""ThermodynamicState._standardize_system() class method."""
Expand Down Expand Up @@ -714,6 +706,47 @@ def test_method_reduced_potential(self):
state.reduced_potential(incompatible_sampler_state)
assert cm.exception.code == ThermodynamicsError.INCOMPATIBLE_SAMPLER_STATE

def test_method_reduced_potential_at_states(self):
"""ThermodynamicState.reduced_potential_at_states() method.

Computing the reduced potential singularly and with the class
method should give the same result.
"""
# Build a mixed collection of compatible and incompatible thermodynamic states.
thermodynamic_states = [
ThermodynamicState(self.alanine_explicit, temperature=300*unit.kelvin,
pressure=1.0*unit.atmosphere),
ThermodynamicState(self.toluene_implicit, temperature=200*unit.kelvin),
ThermodynamicState(self.alanine_explicit, temperature=250*unit.kelvin,
pressure=1.2*unit.atmosphere)
]

# Group thermodynamic states by compatibility.
compatible_groups, original_indices = group_by_compatibility(thermodynamic_states)
assert len(compatible_groups) == 2
assert original_indices == [[0, 2], [1]]

# Compute the reduced potentials.
expected_energies = []
obtained_energies = []
for compatible_group in compatible_groups:
# Create context.
integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds)
context = compatible_group[0].create_context(integrator)
if len(compatible_group) == 2:
context.setPositions(self.alanine_positions)
else:
context.setPositions(self.toluene_positions)

# Compute with single-state method.
for state in compatible_group:
state.apply_to_context(context)
expected_energies.append(state.reduced_potential(context))

# Compute with multi-state method.
obtained_energies.extend(ThermodynamicState.reduced_potential_at_states(context, compatible_group))
assert np.allclose(np.array(expected_energies), np.array(obtained_energies))


# =============================================================================
# TEST SAMPLER STATE
Expand Down Expand Up @@ -796,6 +829,55 @@ def test_constructor_from_context(self):
sampler_state = SamplerState.from_context(alanine_vacuum_context)
assert self.is_sampler_state_equal_context(sampler_state, alanine_vacuum_context)

def test_unitless_cache(self):
"""Test that the unitless cache for positions and velocities is invalidated."""
positions = copy.deepcopy(self.alanine_vacuum_positions)

alanine_vacuum_context = self.create_context(self.alanine_vacuum_state)
alanine_vacuum_context.setPositions(copy.deepcopy(positions))

test_cases = [
SamplerState(positions),
SamplerState.from_context(alanine_vacuum_context)
]

pos_unit = unit.micrometer
vel_unit = unit.micrometer / unit.nanosecond

# Assigning an item invalidates the cache.
for sampler_state in test_cases:
old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions)
sampler_state.positions[5] = [1.0, 1.0, 1.0] * pos_unit
assert sampler_state.positions.has_changed
assert np.all(old_unitless_positions[5] != sampler_state._unitless_positions[5])
sampler_state.positions = copy.deepcopy(positions)
assert sampler_state._unitless_positions_cache is None

# TODO reactivate this test once OpenMM 7.2 is released with bugfix for #1940
# if isinstance(sampler_state._positions._value, np.ndarray):
# old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions)
# sampler_state.positions[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * pos_unit
# assert sampler_state.positions.has_changed
# assert np.all(old_unitless_positions[5:8] != sampler_state._unitless_positions[5:8])

if sampler_state.velocities is not None:
old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities)
sampler_state.velocities[5] = [1.0, 1.0, 1.0] * vel_unit
assert sampler_state.velocities.has_changed
assert np.all(old_unitless_velocities[5] != sampler_state._unitless_velocities[5])
sampler_state.velocities = copy.deepcopy(sampler_state.velocities)
assert sampler_state._unitless_velocities_cache is None

# TODO reactivate this test once OpenMM 7.2 is released with bugfix for #1940
# if isinstance(sampler_state._velocities._value, np.ndarray):
# old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities)
# sampler_state.velocities[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * vel_unit
# assert sampler_state.velocities.has_changed
# assert np.all(old_unitless_velocities[5:8] != sampler_state._unitless_velocities[5:8])
else:
assert sampler_state._unitless_velocities is None


def test_method_is_context_compatible(self):
"""SamplerState.is_context_compatible() method."""
# Vacuum.
Expand Down Expand Up @@ -907,10 +989,9 @@ def dummy_parameter(self):
def dummy_parameter(self, value):
self._dummy_parameter = value

@classmethod
def _standardize_system(cls, system):
def _standardize_system(self, system):
try:
cls.set_dummy_parameter(system, cls.standard_dummy_parameter)
self.set_dummy_parameter(system, self.standard_dummy_parameter)
except TypeError: # No parameter to set.
raise ComposableStateError()

Expand All @@ -933,11 +1014,22 @@ def is_context_compatible(context):
def apply_to_context(self, context):
context.setParameter('dummy_parameter', self.dummy_parameter)

def _on_setattr(self, standard_system, attribute_name):
return False

def _find_force_groups_to_update(self, context, current_context_state, memo):
if current_context_state.dummy_parameter == self.dummy_parameter:
return {}
force, _ = self._find_dummy_force(context.getSystem())
return {force.getForceGroup()}

@classmethod
def add_dummy_parameter(cls, system):
"""Add to system a CustomBondForce with a dummy parameter."""
force = openmm.CustomBondForce('dummy_parameter')
force.addGlobalParameter('dummy_parameter', cls.standard_dummy_parameter)
max_force_group = cls._find_max_force_group(system)
force.setForceGroup(max_force_group + 1)
system.addForce(force)

@staticmethod
Expand All @@ -954,6 +1046,14 @@ def set_dummy_parameter(cls, system, value):
force, parameter_id = cls._find_dummy_force(system)
force.setGlobalParameterDefaultValue(parameter_id, value)

@staticmethod
def _find_max_force_group(system):
max_force_group = 0
for force in system.getForces():
if max_force_group < force.getForceGroup():
max_force_group = force.getForceGroup()
return max_force_group

@classmethod
def get_dummy_parameter(cls, system):
force, parameter_id = cls.DummyState._find_dummy_force(system)
Expand Down Expand Up @@ -1071,7 +1171,7 @@ def test_method_standardize_system(self):
assert not compound_state.is_context_compatible(context)

def test_method_apply_to_context(self):
"""CompoundThermodynamicState.apply_to_context() method."""
"""Test CompoundThermodynamicState.apply_to_context() method."""
dummy_parameter = self.DummyState.standard_dummy_parameter
thermodynamic_state = ThermodynamicState(self.alanine_explicit, self.std_temperature)
thermodynamic_state.pressure = self.std_pressure
Expand All @@ -1090,6 +1190,22 @@ def test_method_apply_to_context(self):
assert context.getParameter('dummy_parameter') == self.dummy_parameter
assert context.getParameter(barostat.Pressure()) == new_pressure / unit.bar

def test_method_find_force_groups_to_update(self):
"""Test CompoundThermodynamicState._find_force_groups_to_update() method."""
alanine_explicit = copy.deepcopy(self.alanine_explicit)
thermodynamic_state = ThermodynamicState(alanine_explicit, self.std_temperature)
compound_state = CompoundThermodynamicState(thermodynamic_state, [self.dummy_state])
context = compound_state.create_context(openmm.VerletIntegrator(2.0*unit.femtoseconds))

# No force group should be updated if the two states are identical.
assert compound_state._find_force_groups_to_update(context, compound_state, memo={}) == set()

# If the dummy parameter changes, there should be 1 force group to update.
compound_state2 = copy.deepcopy(compound_state)
compound_state2.dummy_parameter -= 0.5
group = self.DummyState._find_max_force_group(context.getSystem())
assert compound_state._find_force_groups_to_update(context, compound_state2, memo={}) == {group}


# =============================================================================
# TEST SERIALIZATION
Expand Down
63 changes: 63 additions & 0 deletions openmmtools/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,76 @@ def test_math_eval():
# TEST QUANTITY UTILITIES
# =============================================================================

def test_tracked_quantity():
"""Test TrackedQuantity objects."""
def reset(q):
assert tracked_quantity.has_changed is True
tracked_quantity.has_changed = False

test_cases = [
np.array([10.0, 20.0, 30.0]) * unit.kelvin,
[1.0, 2.0, 3.0] * unit.nanometers,
]
for quantity in test_cases:
tracked_quantity = TrackedQuantity(quantity)
u = tracked_quantity.unit
assert tracked_quantity.has_changed is False

tracked_quantity[0] = 5.0 * u
assert tracked_quantity[0] == 5.0 * u
reset(tracked_quantity)

tracked_quantity[0:2] = [5.0, 6.0] * u
assert np.all(tracked_quantity[0:2] == [5.0, 6.0] * u)
reset(tracked_quantity)

if isinstance(tracked_quantity._value, list):
del tracked_quantity[0]
assert len(tracked_quantity) == 2
reset(tracked_quantity)

tracked_quantity.append(10.0*u)
assert len(tracked_quantity) == 3
reset(tracked_quantity)

tracked_quantity.extend([11.0, 12.0]*u)
assert len(tracked_quantity) == 5
reset(tracked_quantity)

element = 15.0*u
tracked_quantity.insert(1, element)
assert len(tracked_quantity) == 6
reset(tracked_quantity)

tracked_quantity.remove(element.value_in_unit(u))
assert len(tracked_quantity) == 5
reset(tracked_quantity)

assert tracked_quantity.pop().unit == u
assert len(tracked_quantity) == 4
reset(tracked_quantity)
else:
# Check that numpy views are handled correctly.
view = tracked_quantity[:3]
view[0] = 20.0*u
assert tracked_quantity[0] == 20.0*u
reset(tracked_quantity)

view2 = view[1:]
view2[0] = 30.0*u
assert tracked_quantity[1] == 30.0*u
reset(tracked_quantity)


def test_is_quantity_close():
"""Test is_quantity_close method."""
# (quantity1, quantity2, test_result)
test_cases = [(300.0*unit.kelvin, 300.000000004*unit.kelvin, True),
(300.0*unit.kelvin, 300.00000004*unit.kelvin, False),
(1.01325*unit.bar, 1.01325000006*unit.bar, True),
(1.01325*unit.bar, 1.0132500006*unit.bar, False)]

err_msg = 'obtained: {}, expected: {} (quantity1: {}, quantity2: {})'
for quantity1, quantity2, test_result in test_cases:
msg = "Test failed: ({}, {}, {})".format(quantity1, quantity2, test_result)
assert is_quantity_close(quantity1, quantity2) == test_result, msg
Expand Down
51 changes: 51 additions & 0 deletions openmmtools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,57 @@ def _math_eval(node):
# QUANTITY UTILITIES
# =============================================================================

def _changes_state(func):
"""Decorator to signal changes in TrackedQuantity."""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
self.has_changed = True
return func(self, *args, **kwargs)
return wrapper


class TrackedQuantity(unit.Quantity):
"""A quantity that keeps track of whether it has been changed."""

def __init__(self, *args, **kwargs):
super(TrackedQuantity, self).__init__(*args, **kwargs)
self.has_changed = False

def __getitem__(self, item):
if isinstance(item, slice) and isinstance(self._value, np.ndarray):
return TrackedQuantityView(self, super(TrackedQuantity, self).__getitem__(item))
# No need to track a copy.
return super(TrackedQuantity, self).__getitem__(item)

__setitem__ = _changes_state(unit.Quantity.__setitem__)
__delitem__ = _changes_state(unit.Quantity.__delitem__)
append = _changes_state(unit.Quantity.append)
extend = _changes_state(unit.Quantity.extend)
insert = _changes_state(unit.Quantity.insert)
remove = _changes_state(unit.Quantity.remove)
pop = _changes_state(unit.Quantity.pop)


class TrackedQuantityView(unit.Quantity):
"""Keeps truck of a numpy view for TrackedQuantity."""

def __init__(self, tracked_quantity, *args, **kwargs):
super(TrackedQuantityView, self).__init__(*args, **kwargs)
self._tracked_quantity = tracked_quantity # Parent.

def __getitem__(self, item):
if isinstance(item, slice):
return TrackedQuantityView(self._tracked_quantity,
super(TrackedQuantityView, self).__getitem__(item))
# No need to track a copy.
return super(TrackedQuantityView, self).__getitem__(item)

def __setitem__(self, key, value):
super(TrackedQuantityView, self).__setitem__(key, value)
self._tracked_quantity.has_changed = True



# List of simtk.unit methods that are actually units and functions instead of base classes
# Pre-computed to reduce run-time cost
# Get the built-in units
Expand Down
Loading