Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove global data_requirements #3798

Merged
merged 8 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Any,
Dict,
List,
Optional,
Set,
TypeVar,
Union,
Expand All @@ -39,8 +38,6 @@
)

__all__ = [
"data_requirement",
"add_data_requirement",
"select_idx_map",
"make_default_mesh",
"j_must_have",
Expand Down Expand Up @@ -78,64 +75,6 @@
)


# TODO: refactor data_requirement to make it not a global variable
# This registry lives at module level, so any caller can write to it and
# there is no good way to keep track of who changed what.
data_requirement = {}


def add_data_requirement(
    key: str,
    ndof: int,
    atomic: bool = False,
    must: bool = False,
    high_prec: bool = False,
    type_sel: Optional[bool] = None,
    repeat: int = 1,
    default: float = 0.0,
    dtype: Optional[np.dtype] = None,
    output_natoms_for_type_sel: bool = False,
):
    """Register a data requirement for training in the global registry.

    The settings are stored, unchanged, as a dict under
    ``data_requirement[key]``. Registering the same ``key`` again replaces
    the previous entry.

    Parameters
    ----------
    key : str
        type of data stored in the corresponding `*.npy` file, e.g. `forces`
        or `energy`
    ndof : int
        number of degrees of freedom; this is tied to the `atomic` parameter,
        e.g. forces have `atomic=True` and `ndof=3`
    atomic : bool, optional
        specifies whether the `ndof` keyword applies to a per-atom quantity
        or not, by default False
    must : bool, optional
        specifies whether the `*.npy` data file must exist, by default False
    high_prec : bool, optional
        if true load data to `np.float64` else `np.float32`, by default False
    type_sel : bool, optional
        select only certain types of atoms, by default None
    repeat : int, optional
        if specified, repeat the data `repeat` times, by default 1
    default : float, optional, default=0.
        default value of data
    dtype : np.dtype, optional
        the dtype of data, overwrites `high_prec` if provided
    output_natoms_for_type_sel : bool, optional
        if True and type_sel is True, the atomic dimension will be natoms
        instead of nsel
    """
    # Keyword order below fixes the key order of the stored dict.
    entry = dict(
        ndof=ndof,
        atomic=atomic,
        must=must,
        high_prec=high_prec,
        type_sel=type_sel,
        repeat=repeat,
        default=default,
        dtype=dtype,
        output_natoms_for_type_sel=output_natoms_for_type_sel,
    )
    data_requirement[key] = entry


def select_idx_map(atom_types: np.ndarray, select_types: np.ndarray) -> np.ndarray:
"""Build map of indices for element supplied element types from all atoms list.

Expand Down
6 changes: 0 additions & 6 deletions deepmd/tf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from deepmd.common import (
VALID_ACTIVATION,
VALID_PRECISION,
add_data_requirement,
data_requirement,
expand_sys_str,
get_np_precision,
j_loader,
Expand All @@ -47,8 +45,6 @@

__all__ = [
# from deepmd.common
"data_requirement",
"add_data_requirement",
"select_idx_map",
"make_default_mesh",
"j_must_have",
Expand Down Expand Up @@ -291,8 +287,6 @@ def wrapper(self, *args, **kwargs):
def clear_session():
"""Reset all state generated by DeePMD-kit."""
tf.reset_default_graph()
# TODO: remove this line when data_requirement is not a global variable
data_requirement.clear()
_TF_VERSION = Version(TF_VERSION)
if _TF_VERSION < Version("2.4.0"):
tf.train.experimental.disable_mixed_precision_graph_rewrite()
Expand Down
8 changes: 8 additions & 0 deletions deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from deepmd.tf.utils import (
PluginVariant,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)
Expand Down Expand Up @@ -512,3 +515,8 @@
Name suffix to identify this descriptor
"""
raise NotImplementedError(f"Not implemented in class {self.__name__}")

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    This implementation declares no requirements: every access yields a
    new, empty list.
    """
    no_requirements: List[DataRequirementItem] = []
    return no_requirements

Check warning on line 522 in deepmd/tf/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/descriptor.py#L522

Added line #L522 was not covered by tests
20 changes: 15 additions & 5 deletions deepmd/tf/descriptor/se_a_ebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

import numpy as np

from deepmd.tf.common import (
add_data_requirement,
)
from deepmd.tf.env import (
GLOBAL_TF_FLOAT_PRECISION,
op_module,
Expand All @@ -18,6 +15,9 @@
embedding_net,
one_layer,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .descriptor import (
Descriptor,
Expand Down Expand Up @@ -110,8 +110,6 @@
self.type_nlayer = type_nlayer
self.type_one_side = type_one_side
self.numb_aparam = numb_aparam
if self.numb_aparam > 0:
add_data_requirement("aparam", 3, atomic=True, must=True, high_prec=False)

def build(
self,
Expand Down Expand Up @@ -600,3 +598,15 @@
result = tf.reshape(result, [-1, outputs_size_2 * outputs_size])

return result, qmat

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    Takes the list reported by the parent class and, when
    ``self.numb_aparam`` is positive, appends an ``"aparam"`` item to it
    (``atomic=True``, ``must=True``, ``high_prec=False``).
    """
    requirements = super().input_requirement
    if not self.numb_aparam > 0:
        return requirements
    # NOTE(review): ndof is fixed at 3 here rather than taken from
    # ``self.numb_aparam`` -- confirm this is intended.
    aparam_item = DataRequirementItem(
        "aparam", 3, atomic=True, must=True, high_prec=False
    )
    requirements.append(aparam_item)
    return requirements

Check warning on line 612 in deepmd/tf/descriptor/se_a_ebd.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_a_ebd.py#L612

Added line #L612 was not covered by tests
njzjz marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 12 additions & 5 deletions deepmd/tf/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

import numpy as np

from deepmd.tf.common import (
add_data_requirement,
)
from deepmd.tf.env import (
GLOBAL_NP_FLOAT_PRECISION,
GLOBAL_TF_FLOAT_PRECISION,
Expand All @@ -20,6 +17,9 @@
from deepmd.tf.utils.sess import (
run_sess,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .descriptor import (
Descriptor,
Expand Down Expand Up @@ -361,8 +361,6 @@
self.dstd = None
self.davg = None

add_data_requirement("efield", 3, atomic=True, must=True, high_prec=False)

self.place_holders = {}
avg_zero = np.zeros([self.ntypes, self.ndescrpt]).astype(
GLOBAL_NP_FLOAT_PRECISION
Expand Down Expand Up @@ -586,3 +584,12 @@
sysr2.append(sumr2)
sysa2.append(suma2)
return sysr, sysr2, sysa, sysa2, sysn

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    Takes the list reported by the parent class and always appends an
    ``"efield"`` item with 3 degrees of freedom (``atomic=True``,
    ``must=True``, ``high_prec=False``).
    """
    requirements = super().input_requirement
    efield_item = DataRequirementItem(
        "efield", 3, atomic=True, must=True, high_prec=False
    )
    requirements.append(efield_item)
    return requirements

Check warning on line 595 in deepmd/tf/descriptor/se_a_ef.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/descriptor/se_a_ef.py#L595

Added line #L595 was not covered by tests
2 changes: 2 additions & 0 deletions deepmd/tf/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
train_data = get_data(
jdata["training"]["training_data"], rcut, ipt_type_map, modifier
)
train_data.add_data_requirements(model.data_requirements)

Check warning on line 198 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L198

Added line #L198 was not covered by tests
train_data.print_summary("training")
if jdata["training"].get("validation_data", None) is not None:
valid_data = get_data(
Expand All @@ -203,6 +204,7 @@
train_data.type_map,
modifier,
)
valid_data.add_data_requirements(model.data_requirements)

Check warning on line 207 in deepmd/tf/entrypoints/train.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/entrypoints/train.py#L207

Added line #L207 was not covered by tests
valid_data.print_summary("validation")
else:
if modifier is not None:
Expand Down
31 changes: 21 additions & 10 deletions deepmd/tf/fit/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np

from deepmd.tf.common import (
add_data_requirement,
cast_precision,
get_activation_func,
get_precision,
Expand Down Expand Up @@ -43,6 +42,9 @@
from deepmd.tf.utils.network import (
one_layer_rand_seed_shift,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
Expand Down Expand Up @@ -151,18 +153,9 @@

self.useBN = False
self.bias_dos = np.zeros((self.ntypes, self.numb_dos), dtype=np.float64)
# data requirement
if self.numb_fparam > 0:
add_data_requirement(
"fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
)
self.fparam_avg = None
self.fparam_std = None
self.fparam_inv_std = None
if self.numb_aparam > 0:
add_data_requirement(
"aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
)
self.aparam_avg = None
self.aparam_std = None
self.aparam_inv_std = None
Expand Down Expand Up @@ -738,3 +731,21 @@
},
}
return data

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    An ``"fparam"`` item (``atomic=False``) is included when
    ``self.numb_fparam`` is positive, followed by an ``"aparam"`` item
    (``atomic=True``) when ``self.numb_aparam`` is positive. Both use
    ``must=True`` and ``high_prec=False``. The list is empty when neither
    count is positive.
    """
    # (key, ndof, atomic) for each optional parameter, in output order.
    candidates = (
        ("fparam", self.numb_fparam, False),
        ("aparam", self.numb_aparam, True),
    )
    return [
        DataRequirementItem(name, ndof, atomic=atomic, must=True, high_prec=False)
        for name, ndof, atomic in candidates
        if ndof > 0
    ]

Check warning on line 751 in deepmd/tf/fit/dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/fit/dos.py#L751

Added line #L751 was not covered by tests
njzjz marked this conversation as resolved.
Show resolved Hide resolved
31 changes: 21 additions & 10 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np

from deepmd.tf.common import (
add_data_requirement,
cast_precision,
get_activation_func,
get_precision,
Expand Down Expand Up @@ -53,6 +52,9 @@
from deepmd.tf.utils.spin import (
Spin,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.finetune import (
change_energy_bias_lower,
)
Expand Down Expand Up @@ -218,18 +220,9 @@
self.atom_ener.append(None)
self.useBN = False
self.bias_atom_e = np.zeros(self.ntypes, dtype=np.float64)
# data requirement
if self.numb_fparam > 0:
add_data_requirement(
"fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
)
self.fparam_avg = None
self.fparam_std = None
self.fparam_inv_std = None
if self.numb_aparam > 0:
add_data_requirement(
"aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
)
self.aparam_avg = None
self.aparam_std = None
self.aparam_inv_std = None
Expand Down Expand Up @@ -939,3 +932,21 @@
},
}
return data

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    The result holds an ``"fparam"`` item (``atomic=False``) if
    ``self.numb_fparam`` is positive and then an ``"aparam"`` item
    (``atomic=True``) if ``self.numb_aparam`` is positive. Both use
    ``must=True`` and ``high_prec=False``. With neither count positive the
    list is empty.
    """
    # Each row: (key, ndof, atomic); rows are emitted in this order.
    optional_params = (
        ("fparam", self.numb_fparam, False),
        ("aparam", self.numb_aparam, True),
    )
    return [
        DataRequirementItem(label, size, atomic=per_atom, must=True, high_prec=False)
        for label, size, per_atom in optional_params
        if size > 0
    ]

Check warning on line 952 in deepmd/tf/fit/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/fit/ener.py#L952

Added line #L952 was not covered by tests
8 changes: 8 additions & 0 deletions deepmd/tf/fit/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from deepmd.tf.utils import (
PluginVariant,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)
Expand Down Expand Up @@ -252,3 +255,8 @@
# prevent keyError
fitting_net_variables[f"{layer_name}{key}{suffix}/idt"] = 0.0
return fitting_net_variables

@property
def input_requirement(self) -> List[DataRequirementItem]:
    """Return data requirements needed for the model input.

    This implementation declares no requirements: every access yields a
    new, empty list.
    """
    nothing_required: List[DataRequirementItem] = []
    return nothing_required

Check warning on line 262 in deepmd/tf/fit/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/fit/fitting.py#L262

Added line #L262 was not covered by tests
Loading
Loading