From ca2b004448254ffa92d0179164ca7872237af295 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 16:47:28 -0400 Subject: [PATCH 1/2] fix(dpmodel/jax): fix fparam and aparam support in DeepEval Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 4 +- deepmd/dpmodel/infer/deep_eval.py | 21 +++++++-- deepmd/jax/infer/deep_eval.py | 16 +++++-- deepmd/jax/utils/serialization.py | 8 ++-- source/tests/consistent/io/test_io.py | 56 +++++++++++++++++++++++ 5 files changed, 91 insertions(+), 14 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index e55f57c774..58f8639cac 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -388,8 +388,8 @@ def _call_common( assert fparam is not None, "fparam should not be None" if fparam.shape[-1] != self.numb_fparam: raise ValueError( - "get an input fparam of dim {fparam.shape[-1]}, ", - "which is not consistent with {self.numb_fparam}.", + f"get an input fparam of dim {fparam.shape[-1]}, " + f"which is not consistent with {self.numb_fparam}." ) fparam = (fparam - self.fparam_avg) * self.fparam_inv_std fparam = xp.tile( diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index c1f3e4630b..5463743ada 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -204,8 +204,6 @@ def eval( The output of the evaluation. The keys are the names of the output variables, and the values are the corresponding output arrays. """ - if fparam is not None or aparam is not None: - raise NotImplementedError # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) coords = np.array(coords) @@ -216,7 +214,7 @@ def eval( ) request_defs = self._get_request_defs(atomic) out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, request_defs + coords, cells, atom_types, fparam, aparam, request_defs ) return dict( zip( @@ -306,6 +304,8 @@ def _eval_model( coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], ): model = self.dp @@ -323,12 +323,25 @@ def _eval_model( box_input = cells.reshape([-1, 3, 3]) else: box_input = None + if fparam is not None: + fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + else: + fparam_input = None + if aparam is not None: + aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam()) + else: + aparam_input = None do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs ) batch_output = model( - coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial + coord_input, + type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + do_atomic_virial=do_atomic_virial, ) if isinstance(batch_output, tuple): batch_output = batch_output[0] diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 76f044a327..c1967fb0da 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -214,8 +214,6 @@ def eval( The output of the evaluation. The keys are the names of the output variables, and the values are the corresponding output arrays. """ - if fparam is not None or aparam is not None: - raise NotImplementedError # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) coords = np.array(coords) @@ -226,7 +224,7 @@ def eval( ) request_defs = self._get_request_defs(atomic) out = self._eval_func(self._eval_model, numb_test, natoms)( - coords, cells, atom_types, request_defs + coords, cells, atom_types, fparam, aparam, request_defs ) return dict( zip( @@ -316,6 +314,8 @@ def _eval_model( coords: np.ndarray, cells: Optional[np.ndarray], atom_types: np.ndarray, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], request_defs: list[OutputVariableDef], ): model = self.dp @@ -333,6 +333,14 @@ def _eval_model( box_input = cells.reshape([-1, 3, 3]) else: box_input = None + if fparam is not None: + fparam_input = fparam.reshape(nframes, self.get_dim_fparam()) + else: + fparam_input = None + if aparam is not None: + aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam()) + else: + aparam_input = None do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs @@ -341,6 +349,8 @@ def _eval_model( to_jax_array(coord_input), to_jax_array(type_input), box=to_jax_array(box_input), + fparam=to_jax_array(fparam_input), + aparam=to_jax_array(aparam_input), do_atomic_virial=do_atomic_virial, ) if isinstance(batch_output, tuple): diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index fcfcc8a610..a7d57523e2 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -51,18 +51,16 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script = data["model_def_script"] call_lower = model.call_lower - nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape( - "nf, nloc, nghost, nfp, nap" - ) + nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") exported = jax_export.export(jax.jit(call_lower))( jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping - jax.ShapeDtypeStruct((nf, nfp), jnp.float64) + jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() else None, # fparam - jax.ShapeDtypeStruct((nf, nap), jnp.float64) + jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) if model.get_dim_aparam() else None, # aparam False, # do_atomic_virial diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index dc0f280d56..af26c41694 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -136,6 +136,8 @@ def test_deep_eval(self): [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=GLOBAL_NP_FLOAT_PRECISION, ).reshape(1, 9) + natoms = self.atype.shape[1] + nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): @@ -145,10 +147,20 @@ def test_deep_eval(self): reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) deep_eval = DeepEval(prefix + backend.suffixes[0]) + if deep_eval.get_dim_fparam() > 0: + fparam = np.ones((nframes, deep_eval.get_dim_fparam())) + else: + fparam = None + if deep_eval.get_dim_aparam() > 0: + aparam = np.ones((nframes, natoms, deep_eval.get_dim_aparam())) + else: + aparam = None ret = deep_eval.eval( self.coords, self.box, self.atype, + fparam=fparam, + aparam=aparam, ) rets.append(ret) for ret in rets[1:]: @@ -199,3 +211,47 @@ def setUp(self): def tearDown(self): IOTest.tearDown(self) + + +class TestDeepPotFparamAparam(unittest.TestCase, IOTest): + def setUp(self): + model_def_script = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "type": "ener", + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "atom_ener": [], + "seed": 1, + "numb_fparam": 2, + "numb_aparam": 2, + }, + } + model = get_model(copy.deepcopy(model_def_script)) + self.data = { + "model": model.serialize(), + "backend": "test", + "model_def_script": model_def_script, + } + + def tearDown(self): + IOTest.tearDown(self) From 78e99805289956c695126bb08c322631c6d66747 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 16:50:06 -0400 Subject: [PATCH 2/2] fix valueerror Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 58f8639cac..a027e1e59d 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -409,8 +409,8 @@ def _call_common( assert aparam is not None, "aparam should not be None" if aparam.shape[-1] != self.numb_aparam: raise ValueError( - "get an input aparam of dim {aparam.shape[-1]}, ", - "which is not consistent with {self.numb_aparam}.", + f"get an input aparam of dim {aparam.shape[-1]}, " + f"which is not consistent with {self.numb_aparam}." ) aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam]) aparam = (aparam - self.aparam_avg) * self.aparam_inv_std