tf: add explicit mixed_types argument to fittings #3583

Merged
53 changes: 31 additions & 22 deletions deepmd/tf/fit/dipole.py
@@ -62,6 +62,9 @@ class DipoleFittingSeA(Fitting):
The precision of the embedding net parameters. Supported options are |PRECISION|
uniform_seed
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
mixed_types : bool
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
"""

def __init__(
@@ -76,6 +79,7 @@ def __init__(
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
mixed_types: bool = False,
**kwargs,
) -> None:
"""Constructor."""
@@ -100,6 +104,7 @@ def __init__(
self.useBN = False
self.fitting_net_variables = None
self.mixed_prec = None
self.mixed_types = mixed_types

def get_sel_type(self) -> int:
"""Get selected type."""
@@ -109,6 +114,7 @@ def get_out_size(self) -> int:
"""Get the output size. Should be 3."""
return 3

@cast_precision
def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None):
# cut-out inputs
inputs_i = tf.slice(inputs, [0, start_index, 0], [-1, natoms, -1])
@@ -172,7 +178,6 @@ def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None):
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms, 3])
return final_layer

@cast_precision
def build(
self,
input_d: tf.Tensor,
@@ -215,8 +220,12 @@ def build(
start_index = 0
inputs = tf.reshape(input_d, [-1, natoms[0], self.dim_descrpt])
rot_mat = tf.reshape(rot_mat, [-1, natoms[0], self.dim_rot_mat])
if nframes is None:
nframes = tf.shape(inputs)[0]

if type_embedding is not None:
if self.mixed_types or type_embedding is not None:
# keep old behavior
self.mixed_types = True
nloc_mask = tf.reshape(
tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
)
@@ -228,13 +237,30 @@ def build(
self.nloc_masked = tf.shape(
tf.reshape(self.atype_nloc_masked, [nframes, -1])
)[1]

if type_embedding is not None:
atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc_masked)
else:
atype_embed = None

self.atype_embed = atype_embed
if atype_embed is not None:
inputs = tf.reshape(
tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
[-1, self.dim_descrpt],
)
rot_mat = tf.reshape(
tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[
nloc_mask
],
[-1, self.dim_rot_mat_1, 3],
)
atype_embed = tf.cast(atype_embed, self.fitting_precision)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat([inputs, atype_embed], axis=1)
self.dim_descrpt = self.dim_descrpt + type_shape[1]

if atype_embed is None:
if not self.mixed_types:
count = 0
outs_list = []
for type_i in range(self.ntypes):
@@ -255,20 +281,6 @@ def build(
count += 1
outs = tf.concat(outs_list, axis=1)
else:
inputs = tf.reshape(
tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
[-1, self.dim_descrpt],
)
rot_mat = tf.reshape(
tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[
nloc_mask
],
[-1, self.dim_rot_mat_1, 3],
)
atype_embed = tf.cast(atype_embed, self.fitting_precision)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat([inputs, atype_embed], axis=1)
self.dim_descrpt = self.dim_descrpt + type_shape[1]
inputs = tf.reshape(inputs, [nframes, self.nloc_masked, self.dim_descrpt])
rot_mat = tf.reshape(
rot_mat, [nframes, self.nloc_masked, self.dim_rot_mat_1 * 3]
@@ -354,9 +366,7 @@ def serialize(self, suffix: str) -> dict:
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"embedding_width": self.dim_rot_mat_1,
# very bad design: type embedding is not passed to the class
# TODO: refactor the class for type embedding and dipole fitting
"mixed_types": False,
"mixed_types": self.mixed_types,
"dim_out": 3,
"neuron": self.n_neuron,
"resnet_dt": self.resnet_dt,
Expand All @@ -365,8 +375,7 @@ def serialize(self, suffix: str) -> dict:
"exclude_types": [],
"nets": self.serialize_network(
ntypes=self.ntypes,
# TODO: consider type embeddings in dipole fitting
ndim=1,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt,
out_dim=self.dim_rot_mat_1,
neuron=self.n_neuron,
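To make the dipole change concrete: the new mixed_types keyword defaults to False, and build() still switches itself to the shared-net path whenever a type embedding is passed in, so previously trained models behave exactly as before. Below is a minimal sketch of that rule; effective_mixed_types is a hypothetical helper written only for illustration, not a function in the code base.

def effective_mixed_types(mixed_types: bool, type_embedding) -> bool:
    # Old behavior is preserved: a type embedding always implies one fitting
    # net shared by all atom types, regardless of the new flag.
    return mixed_types or type_embedding is not None

assert effective_mixed_types(False, None) is False      # default: one net per atom type
assert effective_mixed_types(True, None) is True        # explicit opt-in to a shared net
assert effective_mixed_types(False, object()) is True   # a type embedding forces the shared net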
30 changes: 17 additions & 13 deletions deepmd/tf/fit/dos.py
@@ -95,6 +95,9 @@
use_aparam_as_mask: bool, optional
If True, the atomic parameters will be used as a mask that determines the atom is real/virtual.
And the aparam will not be used as the atomic parameters for embedding.
mixed_types : bool
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
"""

def __init__(
@@ -114,6 +117,7 @@
uniform_seed: bool = False,
layer_name: Optional[List[Optional[str]]] = None,
use_aparam_as_mask: bool = False,
mixed_types: bool = False,
**kwargs,
) -> None:
"""Constructor."""
@@ -171,6 +175,7 @@
assert (
len(self.layer_name) == len(self.n_neuron) + 1
), "length of layer_name should be that of n_neuron + 1"
self.mixed_types = mixed_types

def get_numb_fparam(self) -> int:
"""Get the number of frame parameters."""
@@ -504,13 +509,22 @@
tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1]
) ## lammps will make error
if type_embedding is not None:
# keep old behavior
self.mixed_types = True

atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
else:
atype_embed = None

self.atype_embed = atype_embed
if atype_embed is not None:
atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat(

[tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
)
self.dim_descrpt = self.dim_descrpt + type_shape[1]

if atype_embed is None:
if not self.mixed_types:
start_index = 0
outs_list = []
for type_i in range(self.ntypes):
@@ -541,13 +555,6 @@
outs = tf.concat(outs_list, axis=1)
# with type embedding
else:
atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat(
[tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
)
original_dim_descrpt = self.dim_descrpt
self.dim_descrpt = self.dim_descrpt + type_shape[1]
inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
final_layer = self._build_lower(
0,
@@ -700,9 +707,7 @@
"var_name": "dos",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
# very bad design: type embedding is not passed to the class
# TODO: refactor the class for DOSFitting and type embedding
"mixed_types": False,
"mixed_types": self.mixed_types,
"dim_out": self.numb_dos,
"neuron": self.n_neuron,
"resnet_dt": self.resnet_dt,
@@ -715,8 +720,7 @@
"exclude_types": [],
"nets": self.serialize_network(
ntypes=self.ntypes,
# TODO: consider type embeddings for DOSFitting
ndim=1,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
out_dim=self.numb_dos,
neuron=self.n_neuron,
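The dos.py part of the diff mainly hoists the type-embedding concatenation so it runs before the per-type/shared branch instead of inside the type-embedding branch. A simplified sketch of that concatenation step, mirroring the tf calls visible in the diff (the standalone helper and its argument names are illustrative assumptions):

import tensorflow as tf

def concat_type_embedding(inputs, atype_nloc, type_embedding, dim_descrpt):
    # Look up the per-atom type embedding and append it to each descriptor
    # row, returning the widened descriptor and its new per-atom width.
    atype_embed = tf.nn.embedding_lookup(type_embedding, atype_nloc)
    atype_embed = tf.cast(atype_embed, inputs.dtype)
    out = tf.concat([tf.reshape(inputs, [-1, dim_descrpt]), atype_embed], axis=1)
    return out, dim_descrpt + atype_embed.get_shape().as_list()[1]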
43 changes: 26 additions & 17 deletions deepmd/tf/fit/ener.py
@@ -141,6 +141,9 @@ class EnerFitting(Fitting):
use_aparam_as_mask: bool, optional
If True, the atomic parameters will be used as a mask that determines the atom is real/virtual.
And the aparam will not be used as the atomic parameters for embedding.
mixed_types : bool
If true, use a uniform fitting net for all atom types, otherwise use
different fitting nets for different atom types.
"""

def __init__(
@@ -162,6 +165,7 @@ def __init__(
layer_name: Optional[List[Optional[str]]] = None,
use_aparam_as_mask: bool = False,
spin: Optional[Spin] = None,
mixed_types: bool = False,
**kwargs,
) -> None:
"""Constructor."""
@@ -238,6 +242,7 @@ def __init__(
assert (
len(self.layer_name) == len(self.n_neuron) + 1
), "length of layer_name should be that of n_neuron + 1"
self.mixed_types = mixed_types

def get_numb_fparam(self) -> int:
"""Get the number of frame parameters."""
@@ -585,6 +590,8 @@ def build(
)
else:
inputs_zero = tf.zeros_like(inputs, dtype=GLOBAL_TF_FLOAT_PRECISION)
else:
inputs_zero = None

if bias_atom_e is not None:
assert len(bias_atom_e) == self.ntypes
@@ -628,13 +635,29 @@
):
type_embedding = nvnmd_cfg.map["t_ebd"]
if type_embedding is not None:
# keep old behavior
self.mixed_types = True
atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
else:
atype_embed = None

self.atype_embed = atype_embed
original_dim_descrpt = self.dim_descrpt
if atype_embed is not None:
atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat(
[tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
)
self.dim_descrpt = self.dim_descrpt + type_shape[1]
if len(self.atom_ener):
assert inputs_zero is not None
inputs_zero = tf.concat(
[tf.reshape(inputs_zero, [-1, original_dim_descrpt]), atype_embed],
axis=1,
)

if atype_embed is None:
if not self.mixed_types:
start_index = 0
outs_list = []
for type_i in range(ntypes_atom):
@@ -673,13 +696,6 @@
outs = tf.concat(outs_list, axis=1)
# with type embedding
else:
atype_embed = tf.cast(atype_embed, GLOBAL_TF_FLOAT_PRECISION)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat(
[tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1
)
original_dim_descrpt = self.dim_descrpt
self.dim_descrpt = self.dim_descrpt + type_shape[1]
inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
final_layer = self._build_lower(
0,
@@ -693,10 +709,6 @@
)
if len(self.atom_ener):
# remove contribution in vacuum
inputs_zero = tf.concat(
[tf.reshape(inputs_zero, [-1, original_dim_descrpt]), atype_embed],
axis=1,
)
inputs_zero = tf.reshape(inputs_zero, [-1, natoms[0], self.dim_descrpt])
zero_layer = self._build_lower(
0,
@@ -892,9 +904,7 @@ def serialize(self, suffix: str = "") -> dict:
"var_name": "energy",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
# very bad design: type embedding is not passed to the class
# TODO: refactor the class for energy fitting and type embedding
"mixed_types": False,
"mixed_types": self.mixed_types,
"dim_out": 1,
"neuron": self.n_neuron,
"resnet_dt": self.resnet_dt,
@@ -912,8 +922,7 @@
"exclude_types": [],
"nets": self.serialize_network(
ntypes=self.ntypes,
# TODO: consider type embeddings for type embedding
ndim=1,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
neuron=self.n_neuron,
activation_function=self.activation_function_name,
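Across all three fittings the resulting control flow is the same: when mixed_types is in effect, a single network is built once; otherwise one network is built per atom type and the per-type outputs are concatenated. serialize() records the flag and emits ndim=0 (no per-type axis in the network container) for the shared case versus ndim=1 otherwise. A minimal sketch of that branch follows; build_one_net and concat_outputs are hypothetical stand-ins for the _build_lower calls and tf.concat used in the diff.

def build_outputs(inputs, ntypes, mixed_types, build_one_net, concat_outputs):
    if mixed_types:
        # one fitting net shared by all atom types (no per-type suffix)
        return build_one_net(inputs, suffix="")
    outs = []
    for type_i in range(ntypes):
        # separate network and variable scope per atom type, as before this PR
        outs.append(build_one_net(inputs, suffix=f"_type_{type_i}"))
    return concat_outputs(outs)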