Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: pairtab model pytorch #174

Closed
wants to merge 18 commits into from
Prev Previous commit
Next Next commit
fix: parallelize nframes
Anyang Peng authored and Anyang Peng committed Jan 23, 2024
commit 180f71903b4071a02dd62425184c8b78627acdd6
77 changes: 42 additions & 35 deletions deepmd_pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -18,13 +18,14 @@
PairTab,
)
import torch
from torch import nn
import numpy as np
from typing import Dict, List, Optional, Union

from deepmd_utils.model_format import FittingOutputDef, OutputVariableDef
from deepmd_pt.model.task import Fitting

class PairTabModel(AtomicModel):
class PairTabModel(nn.Module, AtomicModel):
"""Pairwise tabulation energy model.

This model can be used to tabulate the pairwise energy between atoms for either
@@ -107,44 +108,47 @@ def forward_atomic(


nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc] #this is the atype for local atoms, nframes * nloc
pairwise_dr = self._get_pairwise_dist(extended_coord)
atype = extended_atype[:, :nloc] #this is the atype for local atoms, (nframes, nloc)
pairwise_dr = self._get_pairwise_dist(extended_coord) # (nframes, nall, nall, 3)
pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall), this is the pairwise scalar distance for all atoms in all frames.

self.tab_data = self.tab_data.reshape(self.tab.ntypes,self.tab.ntypes,self.tab.nspline,4)

atomic_energy = torch.zeros(nloc)
atomic_energy = torch.zeros(nframes, nloc)

for atom_i in range(nloc):
i_type = atype[:,i_type] # not quite sure about this on frame dimension
i_type = atype[:,atom_i] # (nframes, 1)
for atom_j in range(nnei):
j_idx = nnei[atom_j]
j_type = extended_atype[:,j_idx] #same here
rr = pairwise_dr[atom_i][atom_j].pow(2).sum().sqrt()

# need to handle i_type and j_type frame dimension
pairwise_ene = self._pair_tabulated_inter(i_type,j_type,rr)
atomic_energy[atom_i] += pairwise_ene
j_type = extended_atype[:,j_idx] # (nframes, 1)

rr = pairwise_rr[:, atom_i, j_idx] # (nframes, 1)

# the input shape is (nframes, 1), (nframes, 1), (nframes, 1),
# the expected output shape then becomes (nframes,1)
pairwise_ene = self._pair_tabulated_inter(i_type, j_type, rr)
atomic_energy[:, atom_i] += pairwise_ene
anyangml marked this conversation as resolved.
Show resolved Hide resolved

return {"atomic_energy": atomic_energy}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the key should be "energy", please check your output def.
you may want to use this decorator to ensure the correctness of your atomic model output

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the key. Not sure about the decorator, this atomic model has no Fitting


def _pair_tabulated_inter(self, i_type: int, j_type: int, rr: torch.Tensor) -> torch.Tensor:
def _pair_tabulated_inter(self, i_type: torch.Tensor, j_type: torch.Tensor, rr: torch.Tensor) -> torch.Tensor:
"""Pairwise tabulated energy.

Parameters
----------
i_type : int
The integer representation of atom type for atom i.
i_type : torch.Tensor
The integer representation of atom type for atom i for all frames.

j_type : int
The integer representation of atom type for atom j.
j_type : torch.Tensor
The integer representation of atom type for atom j for all frames.

rr : torch.Tensor
The salar distance vector between two atoms.
The salar distance vector between two atoms for all frames.

Returns
-------
torch.Tensor
The energy between two atoms.
The energy between two atoms for all frames.

Raises
------
@@ -163,23 +167,26 @@ def _pair_tabulated_inter(self, i_type: int, j_type: int, rr: torch.Tensor) -> t

nspline = int(self.tab_info[2] + 0.1)

uu = (rr - rmin) * hi
uu = (rr - rmin) * hi # this is broadcasted to (nframes,1)

if uu < 0:
if any(uu < 0):
raise Exception("coord go beyond table lower boundary")

idx = int(uu)

if idx >= nspline:
ener = 0
# fscale = 0
return
idx = uu.to(torch.int)

uu -= idx
cur_tab = self.tab_data[i_type.squeeze(),j_type.squeeze()] # this should have shape (nframes, nspline, 4)


a3, a2, a1, a0 = self.tab_data[i_type][j_type][idx]
# if idx >= nspline:
# ener = 0
# # fscale = 0
# return
anyangml marked this conversation as resolved.
Show resolved Hide resolved

final_coef = cur_tab[torch.arange(idx.shape[0]), idx.squeeze()] # this should have shape (nframes, 4)
a3, a2, a1, a0 = final_coef[:,0], final_coef[:,1], final_coef[:,2], final_coef[:,3] # the four coefficients should all be (nframes, 1)

etmp = (a3 * uu + a2) * uu + a1
etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations.
ener = etmp * uu + a0
return ener

@@ -190,22 +197,22 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
Parameters
----------
coords : torch.Tensor
The coordinate of the atoms.
The coordinate of the atoms shape of (nframes * nall * 3).

Returns
-------
torch.Tensor
The pairwise distance between the atoms.
The pairwise distance between the atoms (nframes * nall * nall * 3).

Examples
--------
coords = torch.tensor([
coords = torch.tensor([[
[0,0,0],
[1,3,5],
[2,4,6]
])
]])

dist = tensor([
dist = tensor([[
[[ 0, 0, 0],
[-1, -3, -5],
[-2, -4, -6]],
@@ -217,7 +224,7 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
[[ 2, 4, 6],
[ 1, 1, 1],
[ 0, 0, 0]]
])
]])
"""
return coords.unsqueeze(1) - coords.unsqueeze(0)
return coords.unsqueeze(2) - coords.unsqueeze(1)