Skip to content

Commit

Permalink
keep old behaviors
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Mar 21, 2024
1 parent 89e640c commit 2a61e3c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd/tf/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def build(
nframes = tf.shape(inputs)[0]

if self.mixed_types or type_embedding is not None:
# keep old behavior
self.mixed_types = True
nloc_mask = tf.reshape(
tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/tf/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,8 @@ def build(
tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1]
) ## lammps will make error
if type_embedding is not None:
# keep old behavior
self.mixed_types = True

Check warning on line 513 in deepmd/tf/fit/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/fit/dos.py#L513

Added line #L513 was not covered by tests
atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
else:
atype_embed = None
Expand Down
2 changes: 2 additions & 0 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,8 @@ def build(
):
type_embedding = nvnmd_cfg.map["t_ebd"]
if type_embedding is not None:
# keep old behavior
self.mixed_types = True
atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
else:
atype_embed = None
Expand Down
2 changes: 2 additions & 0 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ def build(
nframes = tf.shape(inputs)[0]

if self.mixed_types or type_embedding is not None:
# keep old behavior
self.mixed_types = True
# nframes x nloc
nloc_mask = tf.reshape(
tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
Expand Down

0 comments on commit 2a61e3c

Please sign in to comment.