Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dpmodel/jax): fix fparam and aparam support in DeepEval #4285

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def _call_common(
assert fparam is not None, "fparam should not be None"
if fparam.shape[-1] != self.numb_fparam:
raise ValueError(
"get an input fparam of dim {fparam.shape[-1]}, ",
"which is not consistent with {self.numb_fparam}.",
f"get an input fparam of dim {fparam.shape[-1]}, "
f"which is not consistent with {self.numb_fparam}."
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = xp.tile(
Expand All @@ -409,8 +409,8 @@ def _call_common(
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
"get an input aparam of dim {aparam.shape[-1]}, ",
"which is not consistent with {self.numb_aparam}.",
f"get an input aparam of dim {aparam.shape[-1]}, "
f"which is not consistent with {self.numb_aparam}."
)
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
Expand Down
21 changes: 17 additions & 4 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -216,7 +214,7 @@ def eval(
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords, cells, atom_types, fparam, aparam, request_defs
)
return dict(
zip(
Expand Down Expand Up @@ -306,6 +304,8 @@ def _eval_model(
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: list[OutputVariableDef],
):
model = self.dp
Expand All @@ -323,12 +323,25 @@ def _eval_model(
box_input = cells.reshape([-1, 3, 3])
else:
box_input = None
if fparam is not None:
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
else:
fparam_input = None
if aparam is not None:
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
else:
aparam_input = None

do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
)
batch_output = model(
coord_input, type_input, box=box_input, do_atomic_virial=do_atomic_virial
coord_input,
type_input,
box=box_input,
fparam=fparam_input,
aparam=aparam_input,
do_atomic_virial=do_atomic_virial,
)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
Expand Down
16 changes: 13 additions & 3 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ def eval(
The output of the evaluation. The keys are the names of the output
variables, and the values are the corresponding output arrays.
"""
if fparam is not None or aparam is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
coords = np.array(coords)
Expand All @@ -226,7 +224,7 @@ def eval(
)
request_defs = self._get_request_defs(atomic)
out = self._eval_func(self._eval_model, numb_test, natoms)(
coords, cells, atom_types, request_defs
coords, cells, atom_types, fparam, aparam, request_defs
)
return dict(
zip(
Expand Down Expand Up @@ -316,6 +314,8 @@ def _eval_model(
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
request_defs: list[OutputVariableDef],
):
model = self.dp
Expand All @@ -333,6 +333,14 @@ def _eval_model(
box_input = cells.reshape([-1, 3, 3])
else:
box_input = None
if fparam is not None:
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
else:
fparam_input = None
if aparam is not None:
aparam_input = aparam.reshape(nframes, natoms, self.get_dim_aparam())
else:
aparam_input = None

do_atomic_virial = any(
x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
Expand All @@ -341,6 +349,8 @@ def _eval_model(
to_jax_array(coord_input),
to_jax_array(type_input),
box=to_jax_array(box_input),
fparam=to_jax_array(fparam_input),
aparam=to_jax_array(aparam_input),
do_atomic_virial=do_atomic_virial,
)
if isinstance(batch_output, tuple):
Expand Down
8 changes: 3 additions & 5 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,16 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
model_def_script = data["model_def_script"]
call_lower = model.call_lower

nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape(
"nf, nloc, nghost, nfp, nap"
)
nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
exported = jax_export.export(jax.jit(call_lower))(
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, nfp), jnp.float64)
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
if model.get_dim_fparam()
else None, # fparam
jax.ShapeDtypeStruct((nf, nap), jnp.float64)
jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
if model.get_dim_aparam()
else None, # aparam
False, # do_atomic_virial
Expand Down
56 changes: 56 additions & 0 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def test_deep_eval(self):
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
).reshape(1, 9)
natoms = self.atype.shape[1]
nframes = self.atype.shape[0]
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
rets = []
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
Expand All @@ -145,10 +147,20 @@ def test_deep_eval(self):
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
deep_eval = DeepEval(prefix + backend.suffixes[0])
if deep_eval.get_dim_fparam() > 0:
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
else:
fparam = None
if deep_eval.get_dim_aparam() > 0:
aparam = np.ones((nframes, natoms, deep_eval.get_dim_aparam()))
else:
aparam = None
ret = deep_eval.eval(
self.coords,
self.box,
self.atype,
fparam=fparam,
aparam=aparam,
)
rets.append(ret)
for ret in rets[1:]:
Expand Down Expand Up @@ -199,3 +211,47 @@ def setUp(self):

def tearDown(self):
IOTest.tearDown(self)


class TestDeepPotFparamAparam(unittest.TestCase, IOTest):
def setUp(self):
model_def_script = {
"type_map": ["O", "H"],
"descriptor": {
"type": "se_e2_a",
"sel": [20, 20],
"rcut_smth": 0.50,
"rcut": 6.00,
"neuron": [
3,
6,
],
"resnet_dt": False,
"axis_neuron": 2,
"precision": "float64",
"type_one_side": True,
"seed": 1,
},
"fitting_net": {
"type": "ener",
"neuron": [
5,
5,
],
"resnet_dt": True,
"precision": "float64",
"atom_ener": [],
"seed": 1,
"numb_fparam": 2,
"numb_aparam": 2,
},
}
model = get_model(copy.deepcopy(model_def_script))
self.data = {
"model": model.serialize(),
"backend": "test",
"model_def_script": model_def_script,
}
njzjz marked this conversation as resolved.
Show resolved Hide resolved

def tearDown(self):
IOTest.tearDown(self)