Skip to content

Commit

Permalink
support collecting MultiSystems (#1422)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Dec 11, 2023
1 parent d90939c commit c3d8f96
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
5 changes: 4 additions & 1 deletion dpgen/collect/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dpdata

from dpgen.generator.run import data_system_fmt
from dpgen.util import expand_sys_str


def collect_data(
Expand All @@ -18,7 +19,8 @@ def collect_data(
# goto input
cwd = os.getcwd()
os.chdir(target_folder)
jdata = json.load(open(param_file))
with open(param_file) as fp:
jdata = json.load(fp)
sys_configs_prefix = jdata.get("sys_configs_prefix", "")
sys_configs = jdata.get("sys_configs", [])
if verbose:
Expand Down Expand Up @@ -46,6 +48,7 @@ def collect_data(
for ii in range(len(iters)):
iter_data = glob.glob(os.path.join(iters[ii], "02.fp", "data.[0-9]*[0-9]"))
iter_data.sort()
iter_data = sum([expand_sys_str(ii) for ii in iter_data], [])
for jj in iter_data:
sys = dpdata.LabeledSystem(jj, fmt="deepmd/npy")
if merge:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
import tempfile
import unittest
from pathlib import Path

import dpdata

from dpgen.collect.collect import collect_data


class TestCollectData(unittest.TestCase):
def setUp(self):
self.data = dpdata.LabeledSystem(
Path(__file__).parent / "generator" / "data" / "deepmd", fmt="deepmd/npy"
)

def test_collect_data(self):
with tempfile.TemporaryDirectory() as inpdir, tempfile.TemporaryDirectory() as outdir, tempfile.NamedTemporaryFile() as param_file:
self.data.to_deepmd_npy(Path(inpdir) / "iter.000000" / "02.fp" / "data.000")
self.data.to_deepmd_npy(
Path(inpdir) / "iter.000001" / "02.fp" / "data.000" / "aa"
)
self.data.to_deepmd_npy(
Path(inpdir) / "iter.000001" / "02.fp" / "data.000" / "bb"
)
with open(param_file.name, "w") as fp:
json.dump(
{"sys_configs": ["sys1"], "model_devi_jobs": [{}, {}, {}]}, fp
)

collect_data(inpdir, param_file.name, outdir, verbose=True)
ms = dpdata.MultiSystems().from_deepmd_npy(outdir)
self.assertEqual(ms.get_nframes(), self.data.get_nframes() * 3)

0 comments on commit c3d8f96

Please sign in to comment.