Skip to content

Commit

Permalink
Add initial re-weighting support (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Oct 29, 2023
1 parent d453bef commit 732f2f0
Show file tree
Hide file tree
Showing 8 changed files with 413 additions and 24 deletions.
1 change: 1 addition & 0 deletions devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies:

# Dev / Testing
- ambertools
- scipy # test logsumexp implementation

- versioneer

Expand Down
3 changes: 2 additions & 1 deletion smee/mm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from smee.mm._config import GenerateCoordsConfig, MinimizationConfig, SimulationConfig
from smee.mm._mm import generate_system_coords, simulate
from smee.mm._ops import compute_ensemble_averages
from smee.mm._ops import compute_ensemble_averages, reweight_ensemble_averages
from smee.mm._reporters import TensorReporter, unpack_frames

__all__ = [
"compute_ensemble_averages",
"generate_system_coords",
"reweight_ensemble_averages",
"simulate",
"GenerateCoordsConfig",
"MinimizationConfig",
Expand Down
187 changes: 179 additions & 8 deletions smee/mm/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
"""Convert from g / mol / Å**3 to g / mL"""


class NotEnoughSamplesError(ValueError):
"""An error raised when an ensemble average is attempted with too few samples."""


class _EnsembleAverageKwargs(typing.TypedDict):
"""The keyword arguments passed to the custom PyTorch op for computing ensemble
averages."""
Expand All @@ -32,6 +36,13 @@ class _EnsembleAverageKwargs(typing.TypedDict):
pressure: float | None


class _ReweightAverageKwargs(_EnsembleAverageKwargs):
"""The keyword arguments passed to the custom PyTorch op for computing re-weighted
ensemble averages."""

min_samples: int


def _pack_force_field(
force_field: smee.TensorForceField,
) -> tuple[tuple[torch.Tensor, ...], dict[str, int], dict[str, int]]:
Expand Down Expand Up @@ -127,6 +138,7 @@ def _compute_frame_observables(
box_vectors: torch.Tensor,
potential_energy: float,
kinetic_energy: float,
beta: float,
pressure: float | None,
) -> dict[str, float]:
"""Compute observables for a given frame in a trajectory.
Expand All @@ -140,15 +152,18 @@ def _compute_frame_observables(
box_vectors: The box vectors [Å] of this frame.
potential_energy: The potential energy [kcal / mol] of this frame.
kinetic_energy: The kinetic energy [kcal / mol] of this frame.
beta: The inverse temperature [mol / kcal].
pressure: The pressure [kcal / mol / Å^3] if NPT.
Returns:
The observables for this frame.
"""

values = {"potential_energy": potential_energy}
reduced_potential = beta * potential_energy

if not system.is_periodic:
values["reduced_potential"] = reduced_potential
return values

volume = torch.det(box_vectors)
Expand All @@ -162,6 +177,10 @@ def _compute_frame_observables(
pv_term = volume * pressure
values["enthalpy"] = potential_energy + kinetic_energy + pv_term

reduced_potential += beta * pv_term

values["reduced_potential"] = reduced_potential

return values


Expand All @@ -170,8 +189,9 @@ def _compute_observables(
force_field: smee.TensorForceField,
frames_file: typing.BinaryIO,
theta: tuple[torch.Tensor],
beta: float,
pressure: float | None = None,
) -> tuple[torch.Tensor, list[str], list[torch.Tensor]]:
) -> tuple[torch.Tensor, list[str], torch.Tensor, list[torch.Tensor]]:
"""Computes the standard set of 'observables', and the gradient of the potential
energy with respect to ``theta`` over a given trajectory.
Expand All @@ -182,24 +202,30 @@ def _compute_observables(
Args:
system: The system that was simulated.
force_field: The force field used to simulate the trajectory.
force_field: The force field to evaluate energies with.
frames_file: The file containing the trajectory.
theta: The parameters to compute the gradient with respect to.
beta: The inverse temperature [mol / kcal].
pressure: The pressure [kcal / mol / Å^3] if NPT.
Returns:
The observables at each frame, the columns of the observable tensor, and the
gradients of the potential energy with respect to each tensor in theta with
The observables at each frame, the columns of the observable tensor, the
reduced potential energy at each frame, and the gradients of the potential
energy with respect to each tensor in theta with
``shape=(n_parameters, n_parameter_cols)``.
"""

needs_grad = [i for i, v in enumerate(theta) if v is not None and v.requires_grad]
du_d_theta = [None if i not in needs_grad else [] for i in range(len(theta))]

reduced_potentials = []

values = []
columns = None

for coords, box_vectors, kinetic in smee.mm._reporters.unpack_frames(frames_file):
for coords, box_vectors, _, kinetic in smee.mm._reporters.unpack_frames(
frames_file
):
coords = coords.to(theta[0].device)
box_vectors = box_vectors.to(theta[0].device)

Expand All @@ -218,19 +244,23 @@ def _compute_observables(
du_d_theta[i].append(du_d_theta_subset[idx].float())

frame = _compute_frame_observables(
system, box_vectors, potential.detach(), kinetic, pressure
system, box_vectors, potential.detach(), kinetic, beta, pressure
)

reduced_potentials.append(frame.pop("reduced_potential"))

if columns is None:
columns = [*frame]

values.append(torch.tensor([frame[c] for c in columns]))

values = torch.stack(values).to(theta[0].device)
reduced_potentials = smee.utils.tensor_like(reduced_potentials, theta[0])

return (
values,
columns,
reduced_potentials,
[v if v is None else torch.stack(v, dim=-1) for v in du_d_theta],
)

Expand All @@ -249,8 +279,8 @@ def forward(ctx, kwargs: _EnsembleAverageKwargs, *theta: torch.Tensor):
system = kwargs["system"]

with kwargs["frames_path"].open("rb") as file:
values, columns, du_d_theta = _compute_observables(
system, force_field, file, theta, pressure=kwargs["pressure"]
values, columns, _, du_d_theta = _compute_observables(
system, force_field, file, theta, kwargs["beta"], kwargs["pressure"]
)

avg_values = values.mean(dim=0)
Expand Down Expand Up @@ -303,6 +333,95 @@ def backward(ctx, *grad_outputs):
return tuple([None] + grads + [None])


class _ReweightAverageOp(torch.autograd.Function):
"""A custom PyTorch op for computing ensemble averages over MD trajectories."""

@staticmethod
def forward(ctx, kwargs: _ReweightAverageKwargs, *theta: torch.Tensor):
force_field = _unpack_force_field(
theta,
kwargs["parameter_lookup"],
kwargs["attribute_lookup"],
kwargs["force_field"],
)
system = kwargs["system"]

with kwargs["frames_path"].open("rb") as file:
values, columns, reduced_pot, du_d_theta = _compute_observables(
system, force_field, file, theta, kwargs["beta"], kwargs["pressure"]
)

with kwargs["frames_path"].open("rb") as file:
reduced_pot_0 = smee.utils.tensor_like(
[v for _, _, v, _ in smee.mm._reporters.unpack_frames(file)],
reduced_pot,
)

delta = (reduced_pot_0 - reduced_pot).double()

ln_weights = delta - torch.logsumexp(delta, dim=0)
weights = torch.exp(ln_weights)

n_effective = torch.exp(-torch.sum(weights * ln_weights, dim=0))

if n_effective < kwargs["min_samples"]:
raise NotEnoughSamplesError

avg_values = (weights[:, None] * values).sum(dim=0)

ctx.beta = kwargs["beta"]
ctx.n_theta = len(theta)
ctx.columns = columns
ctx.save_for_backward(*theta, *du_d_theta, delta, weights, values)

return tuple([*avg_values, columns])

@staticmethod
def backward(ctx, *grad_outputs):
theta = ctx.saved_tensors[: ctx.n_theta]

du_d_theta = ctx.saved_tensors[ctx.n_theta : 2 * ctx.n_theta]
d_reduced_d_theta = [ctx.beta * du for du in du_d_theta if du is not None]

values = ctx.saved_tensors[-1]
weights = ctx.saved_tensors[-2]
delta = ctx.saved_tensors[-3]

grads = [None] * len(theta)

for i in range(len(du_d_theta)):
if du_d_theta[i] is None:
continue

avg_d_reduced_d_theta_i = torch.exp(
smee.utils.logsumexp(delta[None, None, :], -1, b=d_reduced_d_theta[i])
- torch.logsumexp(delta, 0)
)

d_ln_weight_d_theta_i = (
-d_reduced_d_theta[i] + avg_d_reduced_d_theta_i[:, :, None]
)
d_weight_d_theta_i = weights[None, None, :] * d_ln_weight_d_theta_i

d_output_d_theta_i = {
"potential_energy": du_d_theta[i],
"volume": torch.zeros_like(du_d_theta[i]),
"density": torch.zeros_like(du_d_theta[i]),
"enthalpy": du_d_theta[i],
}
d_output_d_theta_i = torch.stack(
[d_output_d_theta_i[column] for column in ctx.columns], dim=-1
)

grads[i] = (
d_weight_d_theta_i[:, :, :, None] * values[None, None, :, :]
+ weights[None, None, :, None] * d_output_d_theta_i
).sum(-2) @ torch.stack(grad_outputs[:-1])

# we need to return one extra 'gradient' for kwargs.
return tuple([None] + grads + [None])


def compute_ensemble_averages(
system: smee.TensorSystem,
force_field: smee.TensorForceField,
Expand Down Expand Up @@ -346,3 +465,55 @@ def compute_ensemble_averages(

*avg_outputs, columns = _EnsembleAverageOp.apply(kwargs, *tensors)
return {column: avg for avg, column in zip(avg_outputs, columns)}


def reweight_ensemble_averages(
system: smee.TensorSystem,
force_field: smee.TensorForceField,
frames_path: pathlib.Path,
temperature: openmm.unit.Quantity,
pressure: openmm.unit.Quantity | None,
min_samples: int = 50,
) -> dict[str, torch.Tensor]:
"""Compute the ensemble average of the potential energy, volume, density,
and enthalpy (if running NPT) by re-weighting an existing MD trajectory.
Args:
system: The system that was simulated.
force_field: The new force field to use.
frames_path: The path to the trajectory to compute the average over.
temperature: The temperature that the trajectory was simulated at.
pressure: The pressure that the trajectory was simulated at.
min_samples: The minimum number of samples required to compute the average.
Raises:
NotEnoughSamplesError: If the number of effective samples is less than
``min_samples``.
Returns:
A dictionary containing the ensemble averages of the potential energy
[kcal/mol], volume [Å^3], density [g/mL], and enthalpy [kcal/mol].
"""
tensors, parameter_lookup, attribute_lookup = _pack_force_field(force_field)

beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * temperature)
beta = beta.value_in_unit(openmm.unit.kilocalorie_per_mole**-1)

if pressure is not None:
pressure = (pressure * openmm.unit.AVOGADRO_CONSTANT_NA).value_in_unit(
openmm.unit.kilocalorie_per_mole / openmm.unit.angstrom**3
)

kwargs: _ReweightAverageKwargs = {
"force_field": force_field,
"parameter_lookup": parameter_lookup,
"attribute_lookup": attribute_lookup,
"system": system,
"frames_path": frames_path,
"beta": beta,
"pressure": pressure,
"min_samples": min_samples,
}

*avg_outputs, columns = _ReweightAverageOp.apply(kwargs, *tensors)
return {column: avg for avg, column in zip(avg_outputs, columns)}
30 changes: 28 additions & 2 deletions smee/mm/_reporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,30 @@ def _decoder(obj, chain=None):
class TensorReporter:
"""A reporter which stores coords, box vectors, and kinetic energy using msgpack."""

def __init__(self, output_file: typing.BinaryIO, report_interval: int):
def __init__(
self,
output_file: typing.BinaryIO,
report_interval: int,
beta: openmm.unit.Quantity,
pressure: openmm.unit.Quantity | None,
):
"""
Args:
output_file: The file to write the frames to.
report_interval: The interval (in steps) at which to write frames.
beta: The inverse temperature the simulation is being run at.
pressure: The pressure the simulation is being run at, or None if NVT /
vacuum.
"""
self._output_file = output_file
self._report_interval = report_interval

self._beta = beta
self._pressure = (
None if pressure is None else pressure * openmm.unit.AVOGADRO_CONSTANT_NA
)

def describeNextReport(self, simulation: openmm.app.Simulation):
steps = self._report_interval - simulation.currentStep % self._report_interval
# requires - positions, velocities, forces, energies?
Expand All @@ -64,12 +78,24 @@ def report(self, simulation: openmm.app.Simulation, state: openmm.State):
if math.isinf(total_energy.value_in_unit(_KCAL_PER_MOL)):
raise ValueError("total energy is infinite")

unreduced_potential = potential_energy

if self._pressure is not None:
unreduced_potential += self._pressure * state.getPeriodicBoxVolume()

reduced_potential = unreduced_potential * self._beta

coords = state.getPositions(asNumpy=True).value_in_unit(_ANGSTROM)
coords = torch.from_numpy(coords).float()
box_vectors = state.getPeriodicBoxVectors(asNumpy=True).value_in_unit(_ANGSTROM)
box_vectors = torch.from_numpy(box_vectors).float()

frame = (coords, box_vectors, kinetic_energy.value_in_unit(_KCAL_PER_MOL))
frame = (
coords,
box_vectors,
reduced_potential,
kinetic_energy.value_in_unit(_KCAL_PER_MOL),
)
self._output_file.write(msgpack.dumps(frame, default=_encoder))


Expand Down
Loading

0 comments on commit 732f2f0

Please sign in to comment.