Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update sel by statistics #3348

Merged
merged 9 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ def deserialize(cls, data: dict) -> "BD":
return BD.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
# call subprocess
cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
return cls.update_sel(global_jdata, local_jdata)

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")

Expand Down
17 changes: 17 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import numpy as np

from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
Expand Down Expand Up @@ -388,3 +391,17 @@ def deserialize(cls, data: dict) -> "DescrptSeA":
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
17 changes: 17 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -324,3 +327,17 @@ def deserialize(cls, data: dict) -> "DescrptSeR":
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
15 changes: 15 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
pass

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
cls = cls.get_class_by_type(local_jdata.get("type", "standard"))
return cls.update_sel(global_jdata, local_jdata)

return BaseBaseModel


Expand Down
20 changes: 19 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
Expand All @@ -14,4 +17,19 @@
# use "class" to resolve "Variable not allowed in type expression"
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
pass
@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = BaseDescriptor.update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy
21 changes: 21 additions & 0 deletions deepmd/dpmodel/utils/update_sel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Type,
)

from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStat,
)
from deepmd.utils.update_sel import (
BaseUpdateSel,
)


class UpdateSel(BaseUpdateSel):
@property
def neighbor_stat(self) -> Type[NeighborStat]:
return NeighborStat

def hook(self, min_nbor_dist, max_nbor_size):
# TODO: save to the model
pass
2 changes: 1 addition & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_train.add_argument(
"--skip-neighbor-stat",
action="store_true",
help="(Supported backend: TensorFlow) Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
help="Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
)
parser_train.add_argument(
# -m has been used by mpi-log
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.infer import (
inference,
)
from deepmd.pt.model.model import (
BaseModel,
)
from deepmd.pt.train import (
training,
)
Expand Down Expand Up @@ -238,6 +241,12 @@ def train(FLAGS):
SummaryPrinter()()
with open(FLAGS.INPUT) as fin:
config = json.load(fin)
if not FLAGS.skip_neighbor_stat:
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
config["model"] = BaseModel.update_sel(config, config["model"])

trainer = get_trainer(
config,
FLAGS.init_model,
Expand Down
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -193,3 +196,17 @@ def forward(
g1 = torch.cat([g1, g1_inp], dim=-1)

return g1, rot_mat, g2, h2, sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True)
32 changes: 32 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
build_multiple_neighbor_list,
get_multiple_nlist_key,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -396,3 +399,32 @@ def forward(
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, rot_mat, g2, h2, sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
update_sel = UpdateSel()
local_jdata_cpy = update_sel.update_one_sel(
global_jdata,
local_jdata_cpy,
True,
rcut_key="repinit_rcut",
sel_key="repinit_nsel",
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
)
local_jdata_cpy = update_sel.update_one_sel(
global_jdata,
local_jdata_cpy,
True,
rcut_key="repformer_rcut",
sel_key="repformer_nsel",
)
return local_jdata_cpy
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
Expand Down Expand Up @@ -228,6 +231,20 @@ def t_cvt(xx):
obj.sea.filter_layers = NetworkCollection.deserialize(embeddings)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)


@DescriptorBlock.register("se_e2_a")
class DescrptBlockSeA(DescriptorBlock):
Expand Down
17 changes: 17 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
Expand Down Expand Up @@ -319,3 +322,17 @@ def t_cvt(xx):
obj["dstd"] = t_cvt(variables["dstd"])
obj.filter_layers = NetworkCollection.deserialize(embeddings)
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
20 changes: 20 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
Expand Down Expand Up @@ -47,3 +50,20 @@ def __new__(cls, descriptor, fitting, *args, **kwargs):
cls = PolarModel
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = BaseDescriptor.update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy
20 changes: 20 additions & 0 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import torch

from deepmd.dpmodel.model.dp_model import (
DPModel,
)
from deepmd.pt.model.atomic_model import (
DPZBLLinearAtomicModel,
)
Expand Down Expand Up @@ -97,3 +100,20 @@ def forward_lower(
model_predict["dforce"] = model_ret["dforce"]
model_predict = model_ret
return model_predict

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"] = DPModel.update_sel(
global_jdata, local_jdata["dpmodel"]
)
return local_jdata_cpy
21 changes: 21 additions & 0 deletions deepmd/pt/utils/update_sel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Type,
)

from deepmd.pt.utils.neighbor_stat import (
NeighborStat,
)
from deepmd.utils.update_sel import (
BaseUpdateSel,
)


class UpdateSel(BaseUpdateSel):
@property
def neighbor_stat(self) -> Type[NeighborStat]:
return NeighborStat

def hook(self, min_nbor_dist, max_nbor_size):
# TODO: save to the model
pass
Loading