Skip to content

Commit

Permalink
refactor: remove global data_requirements (#3798)
Browse files Browse the repository at this point in the history
Fix #3522. Fix #3540.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Introduced a new property `input_requirement` across various models to
    streamline data requirements handling.
  - Added a new property `label_requirement` for loss classes to specify
    data label requirements.

- **Bug Fixes**
  - Refactored and removed outdated data requirement handling methods to
    improve data processing efficiency.

- **Tests**
  - Added new test cases to validate the `input_requirement` and
    `label_requirement` properties.
  - Updated existing tests to align with the new data requirements
    handling approach.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored May 23, 2024
1 parent 8a359db commit 02e4ce9
Show file tree
Hide file tree
Showing 30 changed files with 448 additions and 211 deletions.
61 changes: 0 additions & 61 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Any,
Dict,
List,
Optional,
Set,
TypeVar,
Union,
Expand All @@ -39,8 +38,6 @@
)

__all__ = [
"data_requirement",
"add_data_requirement",
"select_idx_map",
"make_default_mesh",
"j_must_have",
Expand Down Expand Up @@ -78,64 +75,6 @@
)


# TODO: refactor data_requirement to make it not a global variable
# this is not a good way to do things. This is some global variable to which
# anyone can write and there is no good way to keep track of the changes
data_requirement = {}


def add_data_requirement(
    key: str,
    ndof: int,
    atomic: bool = False,
    must: bool = False,
    high_prec: bool = False,
    type_sel: Optional[bool] = None,
    repeat: int = 1,
    default: float = 0.0,
    dtype: Optional[np.dtype] = None,
    output_natoms_for_type_sel: bool = False,
) -> None:
    """Specify data requirements for training.

    The requirement is stored in the module-level ``data_requirement`` dict
    under ``key``; registering the same key again overwrites the previous
    entry.

    Parameters
    ----------
    key : str
        type of data stored in corresponding `*.npy` file e.g. `forces` or `energy`
    ndof : int
        number of the degrees of freedom, this is tied to `atomic` parameter e.g. forces
        have `atomic=True` and `ndof=3`
    atomic : bool, optional
        specifies whether the `ndof` keyword applies to per atom quantity or not,
        by default False
    must : bool, optional
        specifies if the `*.npy` data file must exist, by default False
    high_prec : bool, optional
        if true load data to `np.float64` else `np.float32`, by default False
    type_sel : bool, optional
        select only certain type of atoms, by default None
    repeat : int, optional
        if specified, repeat data `repeat` times, by default 1
    default : float, optional, default=0.
        default value of data
    dtype : np.dtype, optional
        the dtype of data, overwrites `high_prec` if provided
    output_natoms_for_type_sel : bool, optional
        if True and type_sel is True, the atomic dimension will be natoms instead of nsel
    """
    # Plain assignment: the last registration for a given key wins.
    data_requirement[key] = {
        "ndof": ndof,
        "atomic": atomic,
        "must": must,
        "high_prec": high_prec,
        "type_sel": type_sel,
        "repeat": repeat,
        "default": default,
        "dtype": dtype,
        "output_natoms_for_type_sel": output_natoms_for_type_sel,
    }


def select_idx_map(atom_types: np.ndarray, select_types: np.ndarray) -> np.ndarray:
"""Build map of indices for element supplied element types from all atoms list.
Expand Down
6 changes: 0 additions & 6 deletions deepmd/tf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from deepmd.common import (
VALID_ACTIVATION,
VALID_PRECISION,
add_data_requirement,
data_requirement,
expand_sys_str,
get_np_precision,
j_loader,
Expand All @@ -47,8 +45,6 @@

__all__ = [
# from deepmd.common
"data_requirement",
"add_data_requirement",
"select_idx_map",
"make_default_mesh",
"j_must_have",
Expand Down Expand Up @@ -291,8 +287,6 @@ def wrapper(self, *args, **kwargs):
def clear_session():
"""Reset all state generated by DeePMD-kit."""
tf.reset_default_graph()
# TODO: remove this line when data_requirement is not a global variable
data_requirement.clear()
_TF_VERSION = Version(TF_VERSION)
if _TF_VERSION < Version("2.4.0"):
tf.train.experimental.disable_mixed_precision_graph_rewrite()
Expand Down
8 changes: 8 additions & 0 deletions deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from deepmd.tf.utils import (
PluginVariant,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)
Expand Down Expand Up @@ -512,3 +515,8 @@ def serialize(self, suffix: str = "") -> dict:
Name suffix to identify this descriptor
"""
raise NotImplementedError(f"Not implemented in class {self.__name__}")

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    Returns
    -------
    List[DataRequirementItem]
        A newly created empty list on every access, so a caller may append
        to the result without affecting later accesses.
    """
    return []
20 changes: 15 additions & 5 deletions deepmd/tf/descriptor/se_a_ebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

import numpy as np

from deepmd.tf.common import (
add_data_requirement,
)
from deepmd.tf.env import (
GLOBAL_TF_FLOAT_PRECISION,
op_module,
Expand All @@ -18,6 +15,9 @@
embedding_net,
one_layer,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .descriptor import (
Descriptor,
Expand Down Expand Up @@ -110,8 +110,6 @@ def __init__(
self.type_nlayer = type_nlayer
self.type_one_side = type_one_side
self.numb_aparam = numb_aparam
if self.numb_aparam > 0:
add_data_requirement("aparam", 3, atomic=True, must=True, high_prec=False)

def build(
self,
Expand Down Expand Up @@ -600,3 +598,15 @@ def _ebd_filter(
result = tf.reshape(result, [-1, outputs_size_2 * outputs_size])

return result, qmat

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    Starts from the parent class's ``input_requirement`` and, when
    ``self.numb_aparam`` is positive, appends a mandatory per-atom
    ``aparam`` item.

    Returns
    -------
    List[DataRequirementItem]
        The parent's requirements, plus the ``aparam`` item if applicable.
    """
    data_requirement = super().input_requirement
    if self.numb_aparam > 0:
        # NOTE(review): ndof is hard-coded to 3 rather than
        # self.numb_aparam, matching the call this property replaces in
        # __init__ — confirm that this is intended.
        data_requirement.append(
            DataRequirementItem(
                "aparam", 3, atomic=True, must=True, high_prec=False
            )
        )
    return data_requirement
17 changes: 12 additions & 5 deletions deepmd/tf/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

import numpy as np

from deepmd.tf.common import (
add_data_requirement,
)
from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
GLOBAL_TF_FLOAT_PRECISION,
Expand All @@ -20,6 +17,9 @@
from deepmd.tf.utils.sess import (
run_sess,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .descriptor import (
Descriptor,
Expand Down Expand Up @@ -361,8 +361,6 @@ def __init__(
self.dstd = None
self.davg = None

add_data_requirement("efield", 3, atomic=True, must=True, high_prec=False)

self.place_holders = {}
avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype(
GLOBAL_NP_FLOAT_PRECISION
Expand Down Expand Up @@ -586,3 +584,12 @@ def _compute_dstats_sys_smth(
sysr2.append(sumr2)
sysa2.append(suma2)
return sysr, sysr2, sysa, sysa2, sysn

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    Starts from the parent class's ``input_requirement`` and always appends
    a mandatory per-atom ``efield`` item with 3 degrees of freedom.

    Returns
    -------
    List[DataRequirementItem]
        The parent's requirements followed by the ``efield`` item.
    """
    data_requirement = super().input_requirement
    data_requirement.append(
        DataRequirementItem("efield", 3, atomic=True, must=True, high_prec=False)
    )
    return data_requirement
2 changes: 2 additions & 0 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal
train_data = get_data(
jdata["training"]["training_data"], rcut, ipt_type_map, modifier
)
train_data.add_data_requirements(model.data_requirements)
train_data.print_summary("training")
if jdata["training"].get("validation_data", None) is not None:
valid_data = get_data(
Expand All @@ -203,6 +204,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal
train_data.type_map,
modifier,
)
valid_data.add_data_requirements(model.data_requirements)
valid_data.print_summary("validation")
else:
if modifier is not None:
Expand Down
31 changes: 21 additions & 10 deletions deepmd/tf/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np

from deepmd.tf.common import (
add_data_requirement,
cast_precision,
get_activation_func,
get_precision,
Expand Down Expand Up @@ -43,6 +42,9 @@
from deepmd.tf.utils.network import (
one_layer_rand_seed_shift,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
Expand Down Expand Up @@ -151,18 +153,9 @@ def __init__(

self.useBN = False
self.bias_dos = np.zeros((self.ntypes, self.numb_dos), dtype=np.float64)
# data requirement
if self.numb_fparam > 0:
add_data_requirement(
"fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
)
self.fparam_avg = None
self.fparam_std = None
self.fparam_inv_std = None
if self.numb_aparam > 0:
add_data_requirement(
"aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
)
self.aparam_avg = None
self.aparam_std = None
self.aparam_inv_std = None
Expand Down Expand Up @@ -738,3 +731,21 @@ def serialize(self, suffix: str = "") -> dict:
},
}
return data

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    A mandatory ``fparam`` item (not atomic) is added when
    ``self.numb_fparam`` is positive, and a mandatory ``aparam`` item
    (atomic) when ``self.numb_aparam`` is positive. Each item uses the
    corresponding count as its number of degrees of freedom.

    Returns
    -------
    List[DataRequirementItem]
        A new list holding zero, one or two items, ``fparam`` first.
    """
    data_requirement = []
    if self.numb_fparam > 0:
        data_requirement.append(
            DataRequirementItem(
                "fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
            )
        )
    if self.numb_aparam > 0:
        data_requirement.append(
            DataRequirementItem(
                "aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
            )
        )
    return data_requirement
31 changes: 21 additions & 10 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np

from deepmd.tf.common import (
add_data_requirement,
cast_precision,
get_activation_func,
get_precision,
Expand Down Expand Up @@ -53,6 +52,9 @@
from deepmd.tf.utils.spin import (
Spin,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.finetune import (
change_energy_bias_lower,
)
Expand Down Expand Up @@ -218,18 +220,9 @@ def __init__(
self.atom_ener.append(None)
self.useBN = False
self.bias_atom_e = np.zeros(self.ntypes, dtype=np.float64)
# data requirement
if self.numb_fparam > 0:
add_data_requirement(
"fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
)
self.fparam_avg = None
self.fparam_std = None
self.fparam_inv_std = None
if self.numb_aparam > 0:
add_data_requirement(
"aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
)
self.aparam_avg = None
self.aparam_std = None
self.aparam_inv_std = None
Expand Down Expand Up @@ -939,3 +932,21 @@ def serialize(self, suffix: str = "") -> dict:
},
}
return data

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    A mandatory ``fparam`` item (not atomic) is added when
    ``self.numb_fparam`` is positive, and a mandatory ``aparam`` item
    (atomic) when ``self.numb_aparam`` is positive. Each item uses the
    corresponding count as its number of degrees of freedom.

    Returns
    -------
    List[DataRequirementItem]
        A new list holding zero, one or two items, ``fparam`` first.
    """
    data_requirement = []
    if self.numb_fparam > 0:
        data_requirement.append(
            DataRequirementItem(
                "fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
            )
        )
    if self.numb_aparam > 0:
        data_requirement.append(
            DataRequirementItem(
                "aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
            )
        )
    return data_requirement
8 changes: 8 additions & 0 deletions deepmd/tf/fit/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from deepmd.tf.utils import (
PluginVariant,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)
Expand Down Expand Up @@ -252,3 +255,8 @@ def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
# prevent keyError
fitting_net_variables[f"{layer_name}{key}{suffix}/idt"] = 0.0
return fitting_net_variables

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    Returns
    -------
    List[DataRequirementItem]
        A newly created empty list on every access, so a caller may append
        to the result without affecting later accesses.
    """
    return []
Loading

0 comments on commit 02e4ce9

Please sign in to comment.