diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index d1e2fb655f..ac207b4df9 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -1,6 +1,7 @@ import warnings import numpy as np from typing import Tuple, List +from packaging.version import Version from deepmd.env import tf from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision @@ -11,7 +12,7 @@ from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION +from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, TF_VERSION class EnerFitting (Fitting): r"""Fitting the energy of the system. The force and the virial can also be trained. @@ -490,6 +491,11 @@ def build (self, bias_atom_e=0.0, suffix=suffix, reuse=reuse ) outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]]) + # add atom energy bias; TF will broadcast to all batches + # tf.repeat is avaiable in TF>=2.1 or TF 1.15 + _TF_VERSION = Version(TF_VERSION) + if (Version('1.15') <= _TF_VERSION < Version('2') or _TF_VERSION >= Version('2.1')) and self.bias_atom_e is not None: + outs += tf.repeat(tf.constant(self.bias_atom_e, dtype=self.fitting_precision), natoms[2:]) if self.tot_ener_zero: force_tot_ener = 0.0