-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathconvert_mptrj_to_xyz.py
119 lines (98 loc) · 3.41 KB
/
convert_mptrj_to_xyz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
import random
from typing import Any
import ase.io
import ase.units
import numpy as np
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from ase.data import atomic_numbers
from ase.stress import full_3x3_to_voigt_6_stress
from pymatviz.enums import Key
from tqdm import tqdm
# Convert stress from kBar to eV/A^3 and use ASE sign convention
kbar_to_evpa3 = -0.1 / ase.units.GPa
filename = "MPtrj_2022.9_full.json"
train_ratio, val_ratio, test_ratio = 0.9, 0.1, 0.0
print(f"Reading {filename} ...", flush=True)
with open(filename) as file:
data = json.load(file)
def get_id_train_val_test(
total_size: int,
train_ratio: float,
val_ratio: float,
test_ratio: float,
split_seed: int = 123,
) -> tuple[list[int], list[int], list[int]]:
"""Get train, val, test IDs."""
if train_ratio + val_ratio + test_ratio > 1:
raise ValueError("train_ratio + val_ratio + test_ratio is over 1.0")
n_train = int(train_ratio * total_size)
n_val = int(val_ratio * total_size)
n_test = int(test_ratio * total_size)
ids = list(np.arange(total_size))
random.seed(split_seed)
random.shuffle(ids)
id_train = ids[:n_train]
id_val = ids[-(n_val + n_test) : -n_test]
id_test = ids[-n_test:] if n_test != 0 else []
return id_train, id_val, id_test
id_train, id_val, id_test = get_id_train_val_test(
total_size=len(data),
train_ratio=train_ratio,
val_ratio=val_ratio,
test_ratio=test_ratio,
)
info_keys = [
"uncorrected_total_energy",
"ef_per_atom",
"e_per_atom_relaxed",
"ef_per_atom_relaxed",
"magmom",
"bandgap",
"mp_id",
]
def chgnet_to_ase_atoms(datum: dict[str, dict[str, Any]]) -> list[Atoms]:
atoms_list = []
for mat_id, dtm in datum.items():
energy = dtm["uncorrected_total_energy"]
force = dtm["force"]
stress = full_3x3_to_voigt_6_stress(dtm["stress"]) # internal stress
stress *= kbar_to_evpa3 # to eV/Angstrom^3
struct = dtm["structure"]
cell = struct["lattice"]["matrix"]
sites = struct["sites"]
species = [atomic_numbers[site["species"][0]["element"]] for site in sites]
pos = [site["xyz"] for site in sites]
atoms = Atoms(species, pos, cell=cell, pbc=True)
calc_results = {
Key.energy: energy,
Key.free_energy: energy,
Key.forces: force,
Key.stress: stress,
}
calculator = SinglePointCalculator(atoms, **calc_results)
atoms = calculator.get_atoms()
info = {
"data_from": "MP-CHGNet",
"material_id": mat_id.split("-")[0] + "-" + mat_id.split("-")[1],
"calc_id": mat_id.split("-")[2],
"ionic_step_id": mat_id.split("-")[3],
}
for key in info_keys:
info[key] = dtm[key]
atoms.info = info
atoms_list += [atoms]
return atoms_list
dataset = list(data.values())
train_set, val_set, test_set = [], [], []
for idx in tqdm(id_train):
train_set.extend(chgnet_to_ase_atoms(dataset[idx]))
ase.io.write("train.extxyz", train_set, "extxyz", append=True)
for idx in tqdm(id_val):
val_set.extend(chgnet_to_ase_atoms(dataset[idx]))
ase.io.write("valid.extxyz", val_set, "extxyz", append=True)
for idx in tqdm(id_test):
test_set.extend(chgnet_to_ase_atoms(dataset[idx]))
ase.io.write("test.extxyz", test_set, "extxyz", append=True)
print("Done!")