Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pairwise DPRc #2682

Merged
merged 57 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
061c296
init pairwise dprc
njzjz May 26, 2023
a363430
add tests for group_atoms
njzjz May 30, 2023
edf282a
add tests and fix bugs
njzjz Jun 2, 2023
dcb3fed
add tests to test the OP
njzjz Jun 2, 2023
9c64194
fix m_qmmm_frame_idx
njzjz Jun 2, 2023
a1364ff
Merge remote-tracking branch 'origin/devel' into pairwise_dprc
njzjz Jun 5, 2023
1fad763
fix index out of range
njzjz Jun 6, 2023
b19165f
Merge remote-tracking branch 'origin/devel' into pairwise_dprc
njzjz Jun 6, 2023
30b224c
fix model and add tests
njzjz Jun 7, 2023
0b3be1f
skip ut if tf<1.15
njzjz Jun 7, 2023
a81f86b
fix test paths
njzjz Jun 8, 2023
9d08149
fix training errors; add examples
njzjz Jun 8, 2023
0fd198e
build type embedding only once
njzjz Jun 8, 2023
b7d33f0
fix self.ntypes
njzjz Jun 8, 2023
48d845e
Merge branch 'devel' into pairwise_dprc
njzjz Jun 19, 2023
8a94ebb
Merge branch 'devel' into pairwise_dprc
njzjz Jun 21, 2023
d1b5142
add examples for the normal model
njzjz Jun 12, 2023
d9af6ba
make the example models compressible
njzjz Jun 23, 2023
35797eb
fix se_atten variable names when suffix is given
njzjz Jun 23, 2023
837614f
Update se_atten.py
njzjz Jun 23, 2023
33adbb2
Update se_atten.py
njzjz Jun 23, 2023
a25d57f
fix output and init_variables
njzjz Jun 24, 2023
e5e6f31
Merge remote-tracking branch 'fork/fix-se_atten-suffix' into pairwise…
njzjz Jun 24, 2023
089b1b1
support compression
njzjz Jun 25, 2023
105086c
fix se_atten compression when suffix is given
njzjz Jun 25, 2023
df56888
Merge branch 'devel' into pairwise_dprc
njzjz Jun 26, 2023
dba4fde
change the mesh
njzjz Jun 26, 2023
3973b5b
add docs; improve example
njzjz Jul 3, 2023
294fb64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2023
56c10d7
docs: update equations
njzjz Jul 5, 2023
fdb4c57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
5e2a5a4
add compress information
njzjz Jul 7, 2023
3b2b591
add t_mesh to nodes
njzjz Jul 7, 2023
d39ea83
fix the dtype of rcut
njzjz Jul 7, 2023
6d83ef7
fix a typo in reshape
njzjz Jul 7, 2023
4a6b891
another bug fixed
njzjz Jul 7, 2023
9e309f5
Merge remote-tracking branch 'origin/devel' into pairwise_dprc
njzjz Jul 9, 2023
f966356
support fparam/aparam in dp model-devi
njzjz Jul 9, 2023
188ce95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2023
4260309
fix tests
njzjz Jul 9, 2023
4591101
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2023
067addf
sort aparam in the C++ interface
njzjz Jul 10, 2023
1e49f6d
fix typo
njzjz Jul 10, 2023
a2aa14e
sort aparam in the Python API
njzjz Jul 10, 2023
86997ae
Merge remote-tracking branch 'fork/model-devi-fparam-aparam' into pai…
njzjz Jul 10, 2023
f4f7ec1
update the link in README
njzjz Jul 11, 2023
f2b0fa6
fix se_atten tabulate when exclude_types is given
njzjz Jul 14, 2023
d3df072
add tests for compression
njzjz Jul 14, 2023
f11c5ea
fix path to data
njzjz Jul 14, 2023
4121e8a
skip tf 1.14
njzjz Jul 14, 2023
514d25a
fix model name
njzjz Jul 15, 2023
1a4eb5d
fix TabulateFusionSeAGradGradOp
njzjz Jul 15, 2023
b881afc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2023
38744bf
fetch attr in the OP
njzjz Jul 15, 2023
4663813
Merge branch 'devel' into pairwise_dprc
njzjz Jul 15, 2023
06dddcf
fix compress training
njzjz Jul 17, 2023
f41da34
merge documentation
njzjz Jul 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 0 additions & 42 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,48 +324,6 @@ def prod_force_virial(
The atomic virial
"""

def get_feed_dict(
self,
coord_: tf.Tensor,
atype_: tf.Tensor,
natoms: tf.Tensor,
box: tf.Tensor,
mesh: tf.Tensor,
) -> Dict[str, tf.Tensor]:
"""Generate the feed_dict for current descriptor.

Parameters
----------
coord_ : tf.Tensor
The coordinate of atoms
atype_ : tf.Tensor
The type of atoms
natoms : tf.Tensor
The number of atoms. This tensor has the length of Ntypes + 2
natoms[0]: number of local atoms
natoms[1]: total number of atoms held by this processor
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
box : tf.Tensor
The box. Can be generated by deepmd.model.make_stat_input
mesh : tf.Tensor
For historical reasons, only the length of the Tensor matters.
if size of mesh == 6, pbc is assumed.
if size of mesh == 0, no-pbc is assumed.

Returns
-------
feed_dict : dict[str, tf.Tensor]
The output feed_dict of current descriptor
"""
feed_dict = {
"t_coord:0": coord_,
"t_type:0": atype_,
"t_natoms:0": natoms,
"t_box:0": box,
"t_mesh:0": mesh,
}
return feed_dict

def init_variables(
self,
graph: tf.Graph,
Expand Down
1 change: 1 addition & 0 deletions deepmd/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _make_node_names(
"model_attr/model_version",
"train_attr/min_nbor_dist",
"train_attr/training_script",
"t_mesh",
]

if model_type == "ener":
Expand Down
10 changes: 10 additions & 0 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ def get_modifier(modi_data=None):


def get_rcut(jdata):
if jdata["model"].get("type") == "pairwise_dprc":
return max(
jdata["model"]["qm_model"]["descriptor"]["rcut"],
jdata["model"]["qmmm_model"]["descriptor"]["rcut"],
)
descrpt_data = jdata["model"]["descriptor"]
rcut_list = []
if descrpt_data["type"] == "hybrid":
Expand Down Expand Up @@ -499,6 +504,11 @@ def update_sel(jdata):
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
if jdata["model"].get("type") == "pairwise_dprc":
# do not update sel; only find min distance
rcut = get_rcut(jdata)
get_min_nbor_dist(jdata, rcut)
return jdata
descrpt_data = jdata["model"]["descriptor"]
if descrpt_data["type"] == "hybrid":
for ii in range(len(descrpt_data["list"])):
Expand Down
7 changes: 5 additions & 2 deletions deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def build(
input_dict["nframes"] = tf.shape(coord)[0]

# type embedding if any
if self.typeebd is not None:
if self.typeebd is not None and "type_embedding" not in input_dict:
type_embedding = self.typeebd.build(
self.ntypes,
reuse=reuse,
Expand Down Expand Up @@ -368,7 +368,10 @@ def init_variables(
tf.constant("compressed_model", name="model_type", dtype=tf.string)
else:
raise RuntimeError("Unknown model type %s" % model_type)
if self.typeebd is not None:
if (
self.typeebd is not None
and self.typeebd.type_embedding_net_variables is None
):
self.typeebd.init_variables(
graph, graph_def, suffix=suffix, model_type=model_type
)
Expand Down
112 changes: 100 additions & 12 deletions deepmd/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Enum,
)
from typing import (
Dict,
List,
Optional,
Union,
Expand Down Expand Up @@ -82,12 +83,17 @@ def __new__(cls, *args, **kwargs):
from deepmd.model.multi import (
MultiModel,
)
from deepmd.model.pairwise_dprc import (
PairwiseDPRc,
)

model_type = kwargs.get("type", "standard")
if model_type == "standard":
cls = StandardModel
elif model_type == "multi":
cls = MultiModel
elif model_type == "pairwise_dprc":
cls = PairwiseDPRc
else:
raise ValueError(f"unknown model type: {model_type}")
return cls.__new__(cls, *args, **kwargs)
Expand Down Expand Up @@ -261,14 +267,31 @@ def build_descrpt(
suffix=suffix,
reuse=reuse,
)
dout = tf.identity(dout, name="o_descriptor")
dout = tf.identity(dout, name="o_descriptor" + suffix)
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
else:
tf.constant(
self.rcut, name="descrpt_attr/rcut", dtype=GLOBAL_TF_FLOAT_PRECISION
self.rcut,
name="descrpt_attr%s/rcut" % suffix,
dtype=GLOBAL_TF_FLOAT_PRECISION,
)
tf.constant(
self.ntypes, name="descrpt_attr%s/ntypes" % suffix, dtype=tf.int32
)
tf.constant(self.ntypes, name="descrpt_attr/ntypes", dtype=tf.int32)
feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh)
return_elements = [*self.descrpt.get_tensor_names(), "o_descriptor:0"]
if "global_feed_dict" in input_dict:
feed_dict = input_dict["global_feed_dict"]
else:
extra_feed_dict = {}
if "fparam" in input_dict:
extra_feed_dict["fparam"] = input_dict["fparam"]
if "aparam" in input_dict:
extra_feed_dict["aparam"] = input_dict["aparam"]
feed_dict = self.get_feed_dict(
coord_, atype_, natoms, box, mesh, **extra_feed_dict
)
return_elements = [
*self.descrpt.get_tensor_names(suffix=suffix),
"o_descriptor%s:0" % suffix,
]
if frz_model is not None:
imported_tensors = self._import_graph_def_from_frz_model(
frz_model, feed_dict, return_elements
Expand Down Expand Up @@ -343,8 +366,14 @@ def change_energy_bias(
"""
raise RuntimeError("Not supported")

def enable_compression(self):
"""Enable compression."""
def enable_compression(self, suffix: str = ""):
"""Enable compression.

Parameters
----------
suffix : str
suffix to name scope
"""
raise RuntimeError("Not supported")

def get_numb_fparam(self) -> Union[int, dict]:
Expand Down Expand Up @@ -379,6 +408,55 @@ def get_ntypes(self) -> int:
def data_stat(self, data: dict):
"""Data staticis."""

def get_feed_dict(
self,
coord_: tf.Tensor,
atype_: tf.Tensor,
natoms: tf.Tensor,
box: tf.Tensor,
mesh: tf.Tensor,
**kwargs,
) -> Dict[str, tf.Tensor]:
"""Generate the feed_dict for current descriptor.

Parameters
----------
coord_ : tf.Tensor
The coordinate of atoms
atype_ : tf.Tensor
The type of atoms
natoms : tf.Tensor
The number of atoms. This tensor has the length of Ntypes + 2
natoms[0]: number of local atoms
natoms[1]: total number of atoms held by this processor
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
box : tf.Tensor
The box. Can be generated by deepmd.model.make_stat_input
mesh : tf.Tensor
For historical reasons, only the length of the Tensor matters.
if size of mesh == 6, pbc is assumed.
if size of mesh == 0, no-pbc is assumed.
**kwargs : dict
The additional arguments

Returns
-------
feed_dict : dict[str, tf.Tensor]
The output feed_dict of current descriptor
"""
feed_dict = {
"t_coord:0": coord_,
"t_type:0": atype_,
"t_natoms:0": natoms,
"t_box:0": box,
"t_mesh:0": mesh,
}
if kwargs.get("fparam") is not None:
feed_dict["t_fparam:0"] = kwargs["fparam"]
if kwargs.get("aparam") is not None:
feed_dict["t_aparam:0"] = kwargs["aparam"]
return feed_dict


class StandardModel(Model):
"""Standard model, which must contain a descriptor and a fitting.
Expand Down Expand Up @@ -479,8 +557,14 @@ def enable_mixed_precision(self, mixed_prec: dict):
self.descrpt.enable_mixed_precision(mixed_prec)
self.fitting.enable_mixed_precision(mixed_prec)

def enable_compression(self):
"""Enable compression."""
def enable_compression(self, suffix: str = ""):
"""Enable compression.

Parameters
----------
suffix : str
suffix to name scope
"""
graph, graph_def = load_graph_def(self.compress["model_file"])
self.descrpt.enable_compression(
self.compress["min_nbor_dist"],
Expand All @@ -490,11 +574,15 @@ def enable_compression(self):
self.compress["table_config"][1],
self.compress["table_config"][2],
self.compress["table_config"][3],
suffix=suffix,
)
# for fparam or aparam settings in 'ener' type fitting net
self.fitting.init_variables(graph, graph_def)
if self.typeebd is not None:
self.typeebd.init_variables(graph, graph_def)
self.fitting.init_variables(graph, graph_def, suffix=suffix)
if (
self.typeebd is not None
and self.typeebd.type_embedding_net_variables is None
):
self.typeebd.init_variables(graph, graph_def, suffix=suffix)

def get_fitting(self) -> Union[Fitting, dict]:
"""Get the fitting(s)."""
Expand Down
Loading