Skip to content

Commit

Permalink
add atom energy bias to type embedding energy (#1592)
Browse files Browse the repository at this point in the history
* add atom energy bias to type embedding energy

Fix #684, where systems have different `atom_numb`. After this fix, RMSE should be quickly decreased in the very beginning.

* looks like tf.repeat is unavaiable in old TF...

* add statement `self.bias_atom_e is not None`
  • Loading branch information
njzjz authored Mar 25, 2022
1 parent 470f829 commit c17873d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion deepmd/fit/ener.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
import numpy as np
from typing import Tuple, List
from packaging.version import Version

from deepmd.env import tf
from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
Expand All @@ -11,7 +12,7 @@
from deepmd.fit.fitting import Fitting

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, TF_VERSION

class EnerFitting (Fitting):
r"""Fitting the energy of the system. The force and the virial can also be trained.
Expand Down Expand Up @@ -490,6 +491,11 @@ def build (self,
bias_atom_e=0.0, suffix=suffix, reuse=reuse
)
outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]])
# add atom energy bias; TF will broadcast to all batches
# tf.repeat is avaiable in TF>=2.1 or TF 1.15
_TF_VERSION = Version(TF_VERSION)
if (Version('1.15') <= _TF_VERSION < Version('2') or _TF_VERSION >= Version('2.1')) and self.bias_atom_e is not None:
outs += tf.repeat(tf.constant(self.bias_atom_e, dtype=self.fitting_precision), natoms[2:])

if self.tot_ener_zero:
force_tot_ener = 0.0
Expand Down

0 comments on commit c17873d

Please sign in to comment.