Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking: pt: remove data preprocess from data stat #3261

Merged
merged 15 commits
Feb 13, 2024
8 changes: 3 additions & 5 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
Expand All @@ -18,8 +17,6 @@
DescrptBlockSeAtten,
)

log = logging.getLogger(__name__)


@Descriptor.register("dpa1")
@Descriptor.register("se_atten")
Expand Down Expand Up @@ -112,7 +109,7 @@ def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False
return self.se_atten.distinguish_types()

@property
def dim_out(self):
Expand All @@ -128,7 +125,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
Expand All @@ -141,6 +138,7 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["dpa1", "se_atten"]
assert all(x is not None for x in [rcut, rcut_smth, sel])
return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
Expand Down
11 changes: 4 additions & 7 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
Expand Down Expand Up @@ -27,8 +26,6 @@
DescrptBlockSeAtten,
)

log = logging.getLogger(__name__)


@Descriptor.register("dpa2")
class DescrptDPA2(Descriptor):
Expand Down Expand Up @@ -316,7 +313,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
for ii, descrpt in enumerate([self.repinit, self.repformers]):
stat_dict_ii = {
"sumr": sumr[ii],
Expand Down Expand Up @@ -346,8 +343,8 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["dpa2"]
assert True not in [
x is None
assert all(
x is not None
for x in [
repinit_rcut,
repinit_rcut_smth,
Expand All @@ -356,7 +353,7 @@ def get_stat_name(
repformer_rcut_smth,
repformer_nsel,
]
]
)
return (
f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}"
f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz"
Expand Down
10 changes: 9 additions & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@
def get_dim_emb(self):
return self.dim_emb

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return any(
descriptor.distinguish_types() for descriptor in self.descriptor_list
)

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -170,7 +178,7 @@
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])

Check warning on line 181 in deepmd/pt/model/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/hybrid.py#L181

Added line #L181 was not covered by tests
for ii, descrpt in enumerate(self.descriptor_list):
stat_dict_ii = {
"sumr": sumr[ii],
Expand Down
46 changes: 27 additions & 19 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
env,
)
from deepmd.pt.utils.nlist import (
build_neighbor_list,
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
Expand Down Expand Up @@ -178,6 +178,12 @@
"""Returns the embedding dimension g2."""
return self.g2_dim

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -272,44 +278,46 @@
suma2 = []
mixed_type = "real_natoms_vec" in merged[0]
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
index = system["mapping"]
extended_atype = torch.gather(system["atype"], dim=1, index=index)
nloc = system["atype"].shape[-1]
#######################################################
# dirty hack here! the interface of dataload should be
# redesigned to support descriptors like dpa2
#######################################################
nlist = build_neighbor_list(
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
nloc,
self.rcut,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=False,
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
nlist,
system["atype"],
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
if not mixed_type:
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), ndescrpt, natoms
)
else:
real_natoms_vec = system["real_natoms_vec"]

Check warning on line 314 in deepmd/pt/model/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/repformers.py#L314

Added line #L314 was not covered by tests
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(),
ndescrpt,
system["real_natoms_vec"],
real_natoms_vec,
mixed_type=mixed_type,
real_atype=system["atype"].detach().cpu().numpy(),
real_atype=atype.detach().cpu().numpy(),
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
46 changes: 34 additions & 12 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
ClassVar,
List,
Expand Down Expand Up @@ -37,8 +36,9 @@
from deepmd.pt.model.network.network import (
TypeFilter,
)

log = logging.getLogger(__name__)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)


@Descriptor.register("se_e2_a")
Expand Down Expand Up @@ -100,7 +100,7 @@ def distinguish_types(self):
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True
return self.sea.distinguish_types()

@property
def dim_out(self):
Expand All @@ -114,7 +114,7 @@ def compute_input_stats(self, merged):
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
assert all(x is not None for x in [sumr, suma, sumn, sumr2, suma2])
self.sea.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
Expand All @@ -127,7 +127,7 @@ def get_stat_name(
"""
descrpt_type = type_name
assert descrpt_type in ["se_e2_a"]
assert True not in [x is None for x in [rcut, rcut_smth, sel]]
assert all(x is not None for x in [rcut, rcut_smth, sel])
return f"stat_file_descrpt_sea_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
Expand Down Expand Up @@ -347,6 +347,12 @@ def get_dim_in(self) -> int:
"""Returns the input dimension."""
return self.dim_in

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return True

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand Down Expand Up @@ -381,20 +387,36 @@ def compute_input_stats(self, merged):
sumr2 = []
suma2 = []
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
system["nlist"],
system["atype"],
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), self.ndescrpt, natoms
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
42 changes: 34 additions & 8 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
)


@DescriptorBlock.register("se_atten")
Expand Down Expand Up @@ -161,6 +164,12 @@
"""Returns the output dimension of embedding."""
return self.filter_neuron[-1]

def distinguish_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return False

@property
def dim_out(self):
"""Returns the output dimension of this descriptor."""
Expand All @@ -185,29 +194,46 @@
suma2 = []
mixed_type = "real_natoms_vec" in merged[0]
for system in merged:
index = system["mapping"].unsqueeze(-1).expand(-1, -1, 3)
extended_coord = torch.gather(system["coord"], dim=1, index=index)
extended_coord = extended_coord - system["shift"]
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.get_rcut(),
self.get_sel(),
distinguish_types=self.distinguish_types(),
box=box,
)
env_mat, _, _ = prod_env_mat_se_a(
extended_coord,
system["nlist"],
system["atype"],
nlist,
atype,
self.mean,
self.stddev,
self.rcut,
self.rcut_smth,
)
if not mixed_type:
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(), self.ndescrpt, system["natoms"]
env_mat.detach().cpu().numpy(), self.ndescrpt, natoms
)
else:
real_natoms_vec = system["real_natoms_vec"]

Check warning on line 230 in deepmd/pt/model/descriptor/se_atten.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_atten.py#L230

Added line #L230 was not covered by tests
sysr, sysr2, sysa, sysa2, sysn = analyze_descrpt(
env_mat.detach().cpu().numpy(),
self.ndescrpt,
system["real_natoms_vec"],
real_natoms_vec,
mixed_type=mixed_type,
real_atype=system["atype"].detach().cpu().numpy(),
real_atype=atype.detach().cpu().numpy(),
)
sumr.append(sysr)
suma.append(sysa)
Expand Down
Loading