|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +import json |
| 3 | +import os |
| 4 | +import shutil |
| 5 | +import tempfile |
| 6 | +import unittest |
| 7 | +from copy import ( |
| 8 | + deepcopy, |
| 9 | +) |
| 10 | +from pathlib import ( |
| 11 | + Path, |
| 12 | +) |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +import torch |
| 16 | + |
| 17 | +from deepmd.entrypoints.eval_desc import ( |
| 18 | + eval_desc, |
| 19 | +) |
| 20 | +from deepmd.pt.entrypoints.main import ( |
| 21 | + get_trainer, |
| 22 | +) |
| 23 | + |
| 24 | +from .model.test_permutation import ( |
| 25 | + model_se_e2_a, |
| 26 | +) |
| 27 | + |
| 28 | + |
| 29 | +class DPEvalDesc: |
| 30 | + def test_dp_eval_desc_1_frame(self) -> None: |
| 31 | + trainer = get_trainer(deepcopy(self.config)) |
| 32 | + with torch.device("cpu"): |
| 33 | + input_dict, label_dict, _ = trainer.get_data(is_train=False) |
| 34 | + has_spin = getattr(trainer.model, "has_spin", False) |
| 35 | + if callable(has_spin): |
| 36 | + has_spin = has_spin() |
| 37 | + if not has_spin: |
| 38 | + input_dict.pop("spin", None) |
| 39 | + input_dict["do_atomic_virial"] = True |
| 40 | + result = trainer.model(**input_dict) |
| 41 | + model = torch.jit.script(trainer.model) |
| 42 | + tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") |
| 43 | + torch.jit.save(model, tmp_model.name) |
| 44 | + |
| 45 | + # Test eval_desc |
| 46 | + eval_desc( |
| 47 | + model=tmp_model.name, |
| 48 | + system=self.config["training"]["validation_data"]["systems"][0], |
| 49 | + datafile=None, |
| 50 | + output=self.output_dir, |
| 51 | + ) |
| 52 | + os.unlink(tmp_model.name) |
| 53 | + |
| 54 | + # Check that descriptor file was created |
| 55 | + system_name = os.path.basename( |
| 56 | + self.config["training"]["validation_data"]["systems"][0].rstrip("/") |
| 57 | + ) |
| 58 | + desc_file = os.path.join(self.output_dir, f"{system_name}.npy") |
| 59 | + self.assertTrue(os.path.exists(desc_file)) |
| 60 | + |
| 61 | + # Load and validate descriptor |
| 62 | + descriptors = np.load(desc_file) |
| 63 | + self.assertIsInstance(descriptors, np.ndarray) |
| 64 | + # Descriptors should be 3D: (nframes, natoms, ndesc) |
| 65 | + self.assertEqual(len(descriptors.shape), 3) # Should be 3D array |
| 66 | + self.assertGreater(descriptors.shape[0], 0) # Should have frames |
| 67 | + self.assertGreater(descriptors.shape[1], 0) # Should have atoms |
| 68 | + self.assertGreater(descriptors.shape[2], 0) # Should have descriptor dimensions |
| 69 | + |
| 70 | + def tearDown(self) -> None: |
| 71 | + for f in os.listdir("."): |
| 72 | + if f.startswith("model") and f.endswith(".pt"): |
| 73 | + os.remove(f) |
| 74 | + if f in ["lcurve.out", self.input_json]: |
| 75 | + os.remove(f) |
| 76 | + if f in ["stat_files"]: |
| 77 | + shutil.rmtree(f) |
| 78 | + # Clean up output directory |
| 79 | + if hasattr(self, "output_dir") and os.path.exists(self.output_dir): |
| 80 | + shutil.rmtree(self.output_dir) |
| 81 | + |
| 82 | + |
| 83 | +class TestDPEvalDescSeA(DPEvalDesc, unittest.TestCase): |
| 84 | + def setUp(self) -> None: |
| 85 | + self.output_dir = "test_eval_desc_output" |
| 86 | + input_json = str(Path(__file__).parent / "water" / "se_atten.json") |
| 87 | + with open(input_json) as f: |
| 88 | + self.config = json.load(f) |
| 89 | + self.config["training"]["numb_steps"] = 1 |
| 90 | + self.config["training"]["save_freq"] = 1 |
| 91 | + data_file = [str(Path(__file__).parent / "water" / "data" / "single")] |
| 92 | + self.config["training"]["training_data"]["systems"] = data_file |
| 93 | + self.config["training"]["validation_data"]["systems"] = data_file |
| 94 | + self.config["model"] = deepcopy(model_se_e2_a) |
| 95 | + self.input_json = "test_eval_desc.json" |
| 96 | + with open(self.input_json, "w") as fp: |
| 97 | + json.dump(self.config, fp, indent=4) |
| 98 | + |
| 99 | + |
| 100 | +if __name__ == "__main__": |
| 101 | + unittest.main() |
0 commit comments