Skip to content

Commit

Permalink
pt: expand systems before training (#3384)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Mar 2, 2024
1 parent bf4b473 commit c61ba88
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
6 changes: 6 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
process_systems,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -108,6 +111,9 @@ def prepare_trainer_input_single(
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
training_systems = process_systems(training_systems)
if validation_systems is not None:
validation_systems = process_systems(validation_systems)

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
Expand Down
55 changes: 38 additions & 17 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Dict,
List,
Optional,
Union,
)

import numpy as np
Expand Down Expand Up @@ -711,30 +712,22 @@ def prob_sys_size_ext(keywords, nsystems, nbatch):
return sys_probs


def get_data(
jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False
) -> DeepmdDataSystem:
"""Get the data system.
def process_systems(systems: Union[str, List[str]]) -> List[str]:
"""Process the user-input systems.
If it is a single directory, search for all the systems in the directory.
Check if the systems are valid.
Parameters
----------
jdata
The json data
rcut
The cut-off radius, not used
type_map
The type map
modifier
The data modifier
multi_task_mode
If in multi task mode
systems : str or list of str
The user-input systems
Returns
-------
DeepmdDataSystem
The data system
list of str
The valid systems
"""
systems = j_must_have(jdata, "systems")
if isinstance(systems, str):
systems = expand_sys_str(systems)
elif isinstance(systems, list):
Expand All @@ -756,6 +749,34 @@ def get_data(
msg = f"dir {ii} is not a valid data system dir"
log.fatal(msg)
raise OSError(msg, help_msg)
return systems


def get_data(
jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False
) -> DeepmdDataSystem:
"""Get the data system.
Parameters
----------
jdata
The json data
rcut
The cut-off radius, not used
type_map
The type map
modifier
The data modifier
multi_task_mode
If in multi task mode
Returns
-------
DeepmdDataSystem
The data system
"""
systems = j_must_have(jdata, "systems")
systems = process_systems(systems)

batch_size = j_must_have(jdata, "batch_size")
sys_probs = jdata.get("sys_probs", None)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ def is_file(self) -> bool:

def is_dir(self) -> bool:
"""Check if self is directory."""
if self._name == "/":
return True
if self._name not in self._keys:
return False
return isinstance(self.root[self._name], h5py.Group)
Expand Down

0 comments on commit c61ba88

Please sign in to comment.