Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: pairtab model pytorch #174

Closed
wants to merge 18 commits into from
Prev Previous commit
Next Next commit
fix: refactor cubic spline coefficient extraction
Anyang Peng authored and Anyang Peng committed Jan 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit e8103224887018785dd511c36340e759941d7e61
56 changes: 23 additions & 33 deletions deepmd_pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -93,8 +93,8 @@ def get_sel(self)->int:
return self.sel

def distinguish_types(self)->bool:
# this model has no descriptor, thus no type_split.
return
# to match DPA1 and DPA2.
return False

def forward_atomic(
self,
@@ -104,42 +104,39 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:

#this should get atomic energy for all local atoms?


nframes, nloc, nnei = nlist.shape
# atype = extended_atype[:, :nloc]
atype = extended_atype[:, :nloc] #this is the atype for local atoms, nframes * nloc
pairwise_dr = self._get_pairwise_dist(extended_coord)
anyangml marked this conversation as resolved.
Show resolved Hide resolved

"""
below is the pseudocode, need to figure out how the index works.

atomic_energy = torch.zeros(nloc)
for a_loc in range(nloc):

for a_nei in range(nnei):
# there will be duplicated calculation (pairwise), maybe cache it somewhere.
# removing _pair_tab_jloop method, just unwrap here.
self.tab_data = self.tab_data.reshape(self.tab.ntypes,self.tab.ntypes,self.tab.nspline,4)

cur_table_data --> sub-table based on atype.
atomic_energy = torch.zeros(nloc)

rr = pairwise_dr[a_loc][a_nei].pow(2).sum().sqrt() # this is the salar distance.
pairwise_ene = self._pair_tabulated_inter(cur_table_data, rr)
atomic_energy[a_loc] += pairwise_ene

return {"atomic_energy": atomic_energy} --> convert to FittingOutputDef
for atom_i in range(nloc):
i_type = atype[:,i_type] # not quite sure about this on frame dimension
for atom_j in range(nnei):
j_idx = nnei[atom_j]
j_type = extended_atype[:,j_idx] #same here
rr = pairwise_dr[atom_i][atom_j].pow(2).sum().sqrt()

"""
# need to handle i_type and j_type frame dimension
pairwise_ene = self._pair_tabulated_inter(i_type,j_type,rr)
atomic_energy[atom_i] += pairwise_ene

return {"atomic_energy": atomic_energy}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the key should be "energy", please check your output def.
you may want to use this decorator to ensure the correctness of your atomic model output

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the key. Not sure about the decorator, this atomic model has no Fitting


def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, rr: torch.Tensor) -> torch.Tensor:
def _pair_tabulated_inter(self, i_type: int, j_type: int, rr: torch.Tensor) -> torch.Tensor:
"""Pairwise tabulated energy.

Parameters
----------
cur_table_data : torch.Tensor
The tabulated cubic spline coefficients for the current atom types.
i_type : int
The integer representation of atom type for atom i.

j_type : int
The integer representation of atom type for atom j.
anyangml marked this conversation as resolved.
Show resolved Hide resolved

rr : torch.Tensor
The salar distance vector between two atoms.
@@ -165,8 +162,6 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, rr: torch.Tensor)
hi = 1. / hh

nspline = int(self.tab_info[2] + 0.1)
# ndata = nspline * 4


uu = (rr - rmin) * hi

@@ -182,10 +177,7 @@ def _pair_tabulated_inter(self, cur_table_data: torch.Tensor, rr: torch.Tensor)

uu -= idx

a3 = cur_table_data[4 * idx + 0]
a2 = cur_table_data[4 * idx + 1]
a1 = cur_table_data[4 * idx + 2]
a0 = cur_table_data[4 * idx + 3]
a3, a2, a1, a0 = self.tab_data[i_type][j_type][idx]

etmp = (a3 * uu + a2) * uu + a1
ener = etmp * uu + a0
@@ -207,7 +199,6 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:

Examples
--------

coords = torch.tensor([
[0,0,0],
[1,3,5],
@@ -227,7 +218,6 @@ def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor:
[ 1, 1, 1],
[ 0, 0, 0]]
])

"""
return coords.unsqueeze(1) - coords.unsqueeze(0)