Skip to content

Commit

Permalink
fix: unit tests; removed old implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Feb 26, 2024
1 parent c0af6fa commit 8a8107c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 80 deletions.
66 changes: 25 additions & 41 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self.precision = precision
self.prec = PRECISION_DICT[self.precision]
self.resnet_dt = resnet_dt
self.old_impl = old_impl
self.old_impl = False # this does not support old implementation.
self.exclude_types = exclude_types
self.ntypes = len(sel)
self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types)
Expand Down Expand Up @@ -269,55 +269,39 @@ def forward(
self.rcut,
self.rcut_smth,
)
assert dmatrix.shape == (2, 3, 7, 1)

if self.old_impl:
assert self.filter_layers_old is not None
dmatrix = dmatrix.view(
-1, self.ndescrpt
) # shape is [nframes*nall, self.ndescrpt]
xyz_scatter = torch.empty(
1,
device=env.DEVICE,
)
ret = self.filter_layers_old[0](dmatrix)
xyz_scatter = ret
for ii, transform in enumerate(self.filter_layers_old[1:]):
# shape is [nframes*nall, 1, self.filter_neuron[-1]]
ret = transform.forward(dmatrix)
xyz_scatter = xyz_scatter + ret
else:
assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
[nfnl, 1, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
)

assert self.filter_layers is not None
dmatrix = dmatrix.view(-1, self.nnei, 1)
dmatrix = dmatrix.to(dtype=self.prec)
nfnl = dmatrix.shape[0]
# pre-allocate a shape to pass jit
xyz_scatter = torch.zeros(
[nfnl, 1, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
)

# nfnl x nnei
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1)
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 1
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 1 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)
xyz_scatter += gr
# nfnl x nnei
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1)
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 1
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 1 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)
xyz_scatter += gr

xyz_scatter /= self.nnei
xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)

result = torch.matmul(
xyz_scatter_1, xyz_scatter
) # shape is [nframes*nall, self.filter_neuron[-1], 1]
result = result.view(-1, nloc, self.filter_neuron[-1] * 1)
result = result.view(-1, nloc, self.filter_neuron[-1] * self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
Expand Down
40 changes: 1 addition & 39 deletions source/tests/pt/model/test_descriptor_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,45 +102,7 @@ def test_consistency(
# atol=atol,
# err_msg=err_msg,
# )
# old impl
if idt is False and prec == "float64":
dd3 = DescrptSeR(
self.rcut,
self.rcut_smth,
self.sel,
precision=prec,
resnet_dt=idt,
old_impl=True,
).to(env.DEVICE)
dd0_state_dict = dd0.state_dict()
dd3_state_dict = dd3.state_dict()
for i in dd3_state_dict:
dd3_state_dict[i] = (
dd0_state_dict[
i.replace(".deep_layers.", ".layers.").replace(
"filter_layers_old.", "filter_layers.networks."
)
]
.detach()
.clone()
)
if ".bias" in i:
dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0)
dd3.load_state_dict(dd3_state_dict)

rd3, gr3, _, _, sw3 = dd3(
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
)
for aa, bb in zip([rd1, gr1, sw1], [rd3, gr3, sw3]):
np.testing.assert_allclose(
aa.detach().cpu().numpy(),
bb.detach().cpu().numpy(),
rtol=rtol,
atol=atol,
err_msg=err_msg,
)


def test_jit(
self,
Expand Down

0 comments on commit 8a8107c

Please sign in to comment.