Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to run PBC in single system mode #795

Merged
merged 5 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@ def generate_graph(
use_pbc=None,
otf_graph=None,
enforce_max_neighbors_strictly=None,
use_pbc_single=False,
):
cutoff = cutoff or self.cutoff
max_neighbors = max_neighbors or self.max_neighbors
use_pbc = use_pbc or self.use_pbc
use_pbc_single = use_pbc_single or self.use_pbc_single
otf_graph = otf_graph or self.otf_graph

if enforce_max_neighbors_strictly is not None:
Expand Down Expand Up @@ -84,12 +86,47 @@ def generate_graph(

if use_pbc:
if otf_graph:
edge_index, cell_offsets, neighbors = radius_graph_pbc(
data,
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)
if use_pbc_single:
(
edge_index_per_system,
cell_offsets_per_system,
neighbors_per_system,
) = list(
zip(
*[
radius_graph_pbc(
data[idx],
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)
for idx in range(len(data))
]
)
)

# atom indexs in the edge_index need to be offset
atom_index_offset = data.natoms.cumsum(dim=0).roll(1)
atom_index_offset[0] = 0
edge_index = torch.hstack(
[
edge_index_per_system[idx] + atom_index_offset[idx]
for idx in range(len(data))
]
)
cell_offsets = torch.vstack(cell_offsets_per_system)
neighbors = torch.hstack(neighbors_per_system)
else:
## TODO this is the original call, but blows up with memory
## using two different samples
## sid='mp-675045-mp-675045-0-7' (MPTRAJ)
## sid='75396' (OC22)
edge_index, cell_offsets, neighbors = radius_graph_pbc(
data,
cutoff,
max_neighbors,
enforce_max_neighbors_strictly,
)

out = get_pbc_distances(
data.pos,
Expand Down
36 changes: 16 additions & 20 deletions src/fairchem/core/models/dimenet_plus_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,16 +352,13 @@ def forward(
)
}
if self.regress_forces:
outputs["forces"] = (
-1
* (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
outputs["forces"] = -1 * (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
return outputs

Expand All @@ -371,6 +368,7 @@ class DimeNetPlusPlusWrap(DimeNetPlusPlus, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
hidden_channels: int = 128,
num_blocks: int = 4,
Expand All @@ -388,6 +386,7 @@ def __init__(
) -> None:
self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.max_neighbors = 50
Expand Down Expand Up @@ -466,16 +465,13 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
outputs["forces"] = forces

Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class EquiformerV2(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = True,
max_neighbors: int = 500,
Expand Down Expand Up @@ -169,6 +170,7 @@ def __init__(
raise ImportError

self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.max_neighbors = max_neighbors
Expand Down
3 changes: 3 additions & 0 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class eSCN(nn.Module, GraphModelMixin):

Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_neighbors (int): Maximum number of neighbors per atom
Expand All @@ -69,6 +70,7 @@ class eSCN(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = False,
max_neighbors: int = 40,
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(

self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.show_timing_info = show_timing_info
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/gemnet/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
extensive: bool = True,
otf_graph: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
output_init: str = "HeOrthogonal",
activation: str = "swish",
num_elements: int = 83,
Expand All @@ -143,6 +144,7 @@ def __init__(
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# GemNet variants
self.direct_forces = direct_forces
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/gemnet_gp/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
extensive: bool = True,
otf_graph: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
output_init: str = "HeOrthogonal",
activation: str = "swish",
scale_num_blocks: bool = False,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
self.regress_forces = regress_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# GemNet variants
self.direct_forces = direct_forces
Expand Down
4 changes: 4 additions & 0 deletions src/fairchem/core/models/gemnet_oc/gemnet_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class GemNetOC(nn.Module, GraphModelMixin):
If False predict forces based on negative gradient of energy potential.
use_pbc: bool
Whether to use periodic boundary conditions.
use_pbc_single:
Process batch PBC graphs one at a time
scale_backprop_forces: bool
Whether to scale up the energy and then scales down the forces
to prevent NaNs and infs in backpropagated forces.
Expand Down Expand Up @@ -203,6 +205,7 @@ def __init__(
regress_forces: bool = True,
direct_forces: bool = False,
use_pbc: bool = True,
use_pbc_single: bool = False,
scale_backprop_forces: bool = False,
cutoff: float = 6.0,
cutoff_qint: float | None = None,
Expand Down Expand Up @@ -269,6 +272,7 @@ def __init__(
)
self.enforce_max_neighbors_strictly = enforce_max_neighbors_strictly
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

self.direct_forces = direct_forces
self.forces_coupled = forces_coupled
Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/core/models/painn/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
regress_forces: bool = True,
direct_forces: bool = True,
use_pbc: bool = True,
use_pbc_single: bool = False,
otf_graph: bool = True,
num_elements: int = 83,
scale_file: str | None = None,
Expand All @@ -92,6 +93,7 @@ def __init__(
self.direct_forces = direct_forces
self.otf_graph = otf_graph
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single

# Borrowed from GemNet.
self.symmetric_edge_symmetrization = False
Expand Down
20 changes: 10 additions & 10 deletions src/fairchem/core/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SchNetWrap(SchNet, GraphModelMixin):
Args:
use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions.
(default: :obj:`True`)
use_pbc_single (bool,optional): Process batch PBC graphs one at a time
regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating
energy with respect to positions.
(default: :obj:`True`)
Expand All @@ -52,6 +53,7 @@ class SchNetWrap(SchNet, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = False,
regress_forces: bool = True,
otf_graph: bool = False,
hidden_channels: int = 128,
Expand All @@ -64,6 +66,7 @@ def __init__(
self.num_targets = 1
self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.max_neighbors = 50
Expand Down Expand Up @@ -111,16 +114,13 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
outputs["forces"] = forces

Expand Down
3 changes: 3 additions & 0 deletions src/fairchem/core/models/scn/scn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin):

Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_num_neighbors (int): Maximum number of neighbors per atom
Expand Down Expand Up @@ -76,6 +77,7 @@ class SphericalChannelNetwork(nn.Module, GraphModelMixin):
def __init__(
self,
use_pbc: bool = True,
use_pbc_single: bool = True,
regress_forces: bool = True,
otf_graph: bool = False,
max_num_neighbors: int = 20,
Expand Down Expand Up @@ -107,6 +109,7 @@ def __init__(

self.regress_forces = regress_forces
self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
self.show_timing_info = show_timing_info
Expand Down
19 changes: 19 additions & 0 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,25 @@ def test_train_and_predict(
otf_norms=otf_norms,
)

def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
with tempfile.TemporaryDirectory() as tempdirname:
tempdir = Path(tempdirname)
extra_args = {"seed": 0}
_ = _run_main(
rundir=str(tempdir),
update_dict_with={
"optim": {"max_epochs": 1},
"model": {"use_pbc_single": True},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
),
},
update_run_args_with=extra_args,
input_yaml=configs["equiformer_v2"],
)

@pytest.mark.parametrize(
("world_size", "ddp"),
[
Expand Down