Skip to content

Commit

Permalink
Add test for PME exclusions
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd committed Oct 16, 2023
1 parent 466bff3 commit 4184b6c
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 140 deletions.
107 changes: 22 additions & 85 deletions smee/tests/potentials/conftest.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,34 @@
import openff.interchange.models
import openff.units
import copy

import openmm.unit
import pytest
import torch

import smee.potentials
import smee.mm
import smee.tests.utils
import smee.utils


@pytest.fixture()
def mock_lj_potential() -> smee.TensorPotential:
return smee.TensorPotential(
type="vdW",
fn="LJ",
parameters=torch.tensor([[0.1, 1.1], [0.2, 2.1], [0.3, 3.1]]),
parameter_keys=[
openff.interchange.models.PotentialKey(id="[#1:1]"),
openff.interchange.models.PotentialKey(id="[#6:1]"),
openff.interchange.models.PotentialKey(id="[#8:1]"),
],
parameter_cols=("epsilon", "sigma"),
parameter_units=(
openff.units.unit.kilojoule_per_mole,
openff.units.unit.angstrom,
),
attributes=torch.tensor([0.0, 0.0, 0.5, 1.0, 9.0, 2.0]),
attribute_cols=(
"scale_12",
"scale_13",
"scale_14",
"scale_15",
"cutoff",
"switch_width",
),
attribute_units=(
openff.units.unit.dimensionless,
openff.units.unit.dimensionless,
openff.units.unit.dimensionless,
openff.units.unit.dimensionless,
openff.units.unit.angstrom,
openff.units.unit.angstrom,
@pytest.fixture(scope="module")
def _etoh_water_system() -> (
tuple[smee.TensorSystem, smee.TensorForceField, torch.Tensor, torch.Tensor]
):
system, force_field = smee.tests.utils.system_from_smiles(["CCO", "O"], [67, 123])
coords, box_vectors = smee.mm.generate_system_coords(system)

return (
system,
force_field,
torch.tensor(coords.value_in_unit(openmm.unit.angstrom), dtype=torch.float32),
torch.tensor(
box_vectors.value_in_unit(openmm.unit.angstrom), dtype=torch.float32
),
)


@pytest.fixture()
def mock_methane_top() -> smee.TensorTopology:
methane_top = smee.tests.utils.topology_from_smiles("C")
methane_top.parameters = {
"vdW": smee.NonbondedParameterMap(
assignment_matrix=torch.tensor(
[
[0.0, 1.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
]
).to_sparse(),
exclusions=torch.tensor(
[
[0, 1],
[0, 2],
[0, 3],
[0, 4],
[1, 2],
[1, 3],
[1, 4],
[2, 3],
[2, 4],
[3, 4],
]
),
exclusion_scale_idxs=torch.tensor([[0] * 4 + [1] * 6]),
)
}
return methane_top

def etoh_water_system(
_etoh_water_system,
) -> tuple[smee.TensorSystem, smee.TensorForceField, torch.Tensor, torch.Tensor]:
"""Creates a system of ethanol and water."""

@pytest.fixture()
def mock_water_top() -> smee.TensorTopology:
methane_top = smee.tests.utils.topology_from_smiles("O")
methane_top.parameters = {
"vdW": smee.NonbondedParameterMap(
assignment_matrix=torch.tensor(
[
[0.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
]
).to_sparse(),
exclusions=torch.tensor([[0, 1], [0, 2], [1, 2]]),
exclusion_scale_idxs=torch.tensor([[0], [0], [1]]),
)
}
return methane_top
return copy.deepcopy(_etoh_water_system)
104 changes: 59 additions & 45 deletions smee/tests/potentials/test_nonbonded.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import numpy
import openff.interchange
import openff.toolkit
import openmm.unit
import torch

import smee.converters
import smee.converters.openmm
import smee.mm
import smee.tests.utils
from smee.potentials.nonbonded import (
_COULOMB_PRE_FACTOR,
_compute_coulomb_energy_periodic,
_compute_lj_energy_periodic,
_compute_pair_scales,
_compute_pme_exclusions,
compute_coulomb_energy,
compute_lj_energy,
)
Expand All @@ -23,6 +23,9 @@ def compute_openmm_periodic_energy(
box_vectors: openmm.unit.Quantity,
potential: smee.TensorPotential,
) -> torch.Tensor:
coords = coords.numpy() * openmm.unit.angstrom
box_vectors = box_vectors.numpy() * openmm.unit.angstrom

omm_force = smee.converters.convert_to_openmm_force(potential, system)

omm_system = smee.converters.openmm.create_openmm_system(system)
Expand All @@ -48,6 +51,45 @@ def test_coulomb_pre_factor():
assert numpy.isclose(_COULOMB_PRE_FACTOR * _KCAL_TO_KJ, 1389.3545764, atol=1.0e-7)


def test_compute_pme_exclusions():
system, force_field = smee.tests.utils.system_from_smiles(["C", "O"], [2, 3])

coulomb_potential = force_field.potentials_by_type["Electrostatics"]
exclusions = _compute_pme_exclusions(system, coulomb_potential)

expected_exclusions = torch.tensor(
[
# C #1
[1, 2, 3, 4],
[0, 2, 3, 4],
[0, 1, 3, 4],
[0, 1, 2, 4],
[0, 1, 2, 3],
# C #2
[6, 7, 8, 9],
[5, 7, 8, 9],
[5, 6, 8, 9],
[5, 6, 7, 9],
[5, 6, 7, 8],
# O #1
[11, 12, -1, -1],
[10, 12, -1, -1],
[10, 11, -1, -1],
# O #2
[14, 15, -1, -1],
[13, 15, -1, -1],
[13, 14, -1, -1],
# O #3
[17, 18, -1, -1],
[16, 18, -1, -1],
[16, 17, -1, -1],
]
)

assert exclusions.shape == expected_exclusions.shape
assert torch.allclose(exclusions, expected_exclusions)


def test_compute_coulomb_energy_two_particle():
scale_factor = 5.0

Expand Down Expand Up @@ -93,30 +135,14 @@ def test_compute_coulomb_energy_three_particle():
assert torch.isclose(expected_energy, actual_energy)


def test_compute_coulomb_energy_periodic():
interchanges = [
openff.interchange.Interchange.from_smirnoff(
openff.toolkit.ForceField("openff-2.0.0.offxml"),
openff.toolkit.Molecule.from_smiles(smiles).to_topology(),
)
for smiles in ["CCO", "O"]
]

tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)
tensor_sys = smee.TensorSystem(tensor_tops, [67, 123], is_periodic=True)
def test_compute_coulomb_energy_periodic(etoh_water_system):
tensor_sys, tensor_ff, coords, box_vectors = etoh_water_system

coulomb_potential = tensor_ff.potentials_by_type["Electrostatics"]
coulomb_potential.parameters.requires_grad = True

coords, box_vectors = smee.mm.generate_system_coords(tensor_sys)

energy = _compute_coulomb_energy_periodic(
tensor_sys,
torch.tensor(coords.value_in_unit(openmm.unit.angstrom), dtype=torch.float32),
torch.tensor(
box_vectors.value_in_unit(openmm.unit.angstrom), dtype=torch.float32
),
coulomb_potential,
tensor_sys, coords.float(), box_vectors.float(), coulomb_potential
)
energy.backward()

Expand Down Expand Up @@ -192,12 +218,15 @@ def test_compute_lj_energy_three_particle():
assert torch.isclose(expected_energy, actual_energy)


def test_compute_pair_scales(mock_lj_potential, mock_methane_top, mock_water_top):
mock_lj_potential.attributes = torch.tensor([0.01, 0.02, 0.5, 1.0, 9.0, 2.0])
def test_compute_pair_scales():
system, force_field = smee.tests.utils.system_from_smiles(["C", "O"], [2, 3])

system = smee.TensorSystem([mock_methane_top, mock_water_top], [2, 3], True)
vdw_potential = force_field.potentials_by_type["vdW"]
vdw_potential.attributes = torch.tensor(
[0.01, 0.02, 0.5, 1.0, 9.0, 2.0], dtype=torch.float64
)

scales = _compute_pair_scales(system, mock_lj_potential)
scales = _compute_pair_scales(system, vdw_potential)

# fmt: off
expected_scale_matrix = torch.tensor(
Expand Down Expand Up @@ -225,7 +254,8 @@ def test_compute_pair_scales(mock_lj_potential, mock_methane_top, mock_water_top
[1.0] * 16 + [1.0, 0.01, 0.01],
[1.0] * 16 + [0.01, 1.0, 0.02],
[1.0] * 16 + [0.01, 0.02, 1.0],
]
],
dtype=torch.float64
)
# fmt: on

Expand All @@ -236,30 +266,14 @@ def test_compute_pair_scales(mock_lj_potential, mock_methane_top, mock_water_top
assert torch.allclose(scales, expected_scales)


def test_compute_lj_energy_periodic():
interchanges = [
openff.interchange.Interchange.from_smirnoff(
openff.toolkit.ForceField("openff-2.0.0.offxml"),
openff.toolkit.Molecule.from_smiles(smiles).to_topology(),
)
for smiles in ["CCO", "O"]
]

tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)
tensor_sys = smee.TensorSystem(tensor_tops, [67, 123], is_periodic=True)
def test_compute_lj_energy_periodic(etoh_water_system):
tensor_sys, tensor_ff, coords, box_vectors = etoh_water_system

vdw_potential = tensor_ff.potentials_by_type["vdW"]
vdw_potential.parameters.requires_grad = True

coords, box_vectors = smee.mm.generate_system_coords(tensor_sys)

energy = _compute_lj_energy_periodic(
tensor_sys,
torch.tensor(coords.value_in_unit(openmm.unit.angstrom), dtype=torch.float32),
torch.tensor(
box_vectors.value_in_unit(openmm.unit.angstrom), dtype=torch.float32
),
vdw_potential,
tensor_sys, coords.float(), box_vectors.float(), vdw_potential
)
energy.backward()

Expand Down
21 changes: 11 additions & 10 deletions smee/tests/potentials/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@
from smee.potentials import broadcast_parameters, compute_energy


def test_broadcast_parameters(mock_lj_potential, mock_methane_top, mock_water_top):
system = smee.TensorSystem([mock_methane_top, mock_water_top], [2, 3], True)
def test_broadcast_parameters():
system, force_field = smee.tests.utils.system_from_smiles(["C", "O"], [2, 3])
vdw_potential = force_field.potentials_by_type["vdW"]

parameters = broadcast_parameters(system, mock_lj_potential)
methane_top, water_top = system.topologies

methane_parameters = (
mock_methane_top.parameters["vdW"].assignment_matrix
@ mock_lj_potential.parameters
parameters = broadcast_parameters(system, vdw_potential)

expected_methane_parameters = (
methane_top.parameters["vdW"].assignment_matrix @ vdw_potential.parameters
)
water_parameters = (
mock_water_top.parameters["vdW"].assignment_matrix
@ mock_lj_potential.parameters
expected_water_parameters = (
water_top.parameters["vdW"].assignment_matrix @ vdw_potential.parameters
)

expected_parameters = torch.vstack(
[methane_parameters] * 2 + [water_parameters] * 3
[expected_methane_parameters] * 2 + [expected_water_parameters] * 3
)
assert parameters.shape == expected_parameters.shape
assert torch.allclose(parameters, expected_parameters)
Expand Down
30 changes: 30 additions & 0 deletions smee/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import openff.interchange
import openff.toolkit
import torch
from rdkit import Chem

import smee
import smee.converters


def topology_from_smiles(smiles: str) -> smee.TensorTopology:
Expand Down Expand Up @@ -30,3 +33,30 @@ def topology_from_smiles(smiles: str) -> smee.TensorTopology:
v_sites=None,
constraints=None,
)


def system_from_smiles(
smiles: list[str], n_copies: list[int]
) -> tuple[smee.TensorSystem, smee.TensorForceField]:
"""Creates a system from a list of SMILES strings.
Args:
smiles: The list of SMILES strings.
n_copies: The number of copies of each molecule.
Returns:
The system and force field.
"""
force_field = openff.toolkit.ForceField("openff-2.0.0.offxml")

interchanges = [
openff.interchange.Interchange.from_smirnoff(
force_field,
openff.toolkit.Molecule.from_smiles(pattern).to_topology(),
)
for pattern in smiles
]

tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)

return smee.TensorSystem(tensor_tops, n_copies, is_periodic=True), tensor_ff

0 comments on commit 4184b6c

Please sign in to comment.