Skip to content

Commit 2c5b95f

Browse files
feat: support zero-count elements in type_map for sort_atom_names
1 parent 7af8d74 commit 2c5b95f

File tree

1 file changed

+55
-21
lines changed

1 file changed

+55
-21
lines changed

dpdata/utils.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def add_atom_names(data, atom_names):
6262

6363

6464
def sort_atom_names(data, type_map=None):
65-
"""Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
65+
"""Sort atom_names of the system and reorder atom_numbs and atom_types according
6666
to atom_names. If type_map is not given, atom_names will be sorted by
67-
alphabetical order. If type_map is given, atom_names will be type_map.
67+
alphabetical order. If type_map is given, atom_names will be set to type_map,
68+
and zero-count elements are kept.
6869
6970
Parameters
7071
----------
@@ -74,28 +75,61 @@ def sort_atom_names(data, type_map=None):
7475
type_map
7576
"""
7677
if type_map is not None:
77-
# assign atom_names index to the specify order
78-
# atom_names must be a subset of type_map
79-
assert set(data["atom_names"]).issubset(set(type_map))
80-
# for the condition that type_map is a proper superset of atom_names
81-
# new_atoms = set(type_map) - set(data["atom_names"])
82-
new_atoms = [e for e in type_map if e not in data["atom_names"]]
83-
if new_atoms:
84-
data = add_atom_names(data, new_atoms)
85-
# index that will sort an array by type_map
86-
# a[as[a]] == b[as[b]] as == argsort
87-
# as[as[b]] == as^{-1}[b]
88-
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
89-
idx = np.argsort(data["atom_names"], kind="stable")[
90-
np.argsort(np.argsort(type_map, kind="stable"), kind="stable")
91-
]
78+
# assign atom_names index to the specified order
79+
# only active (numb > 0) atom names must be in type_map
80+
orig_names = data["atom_names"]
81+
orig_numbs = data["atom_numbs"]
82+
active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0}
83+
type_map_set = set(type_map)
84+
if not active_names.issubset(type_map_set):
85+
missing = active_names - type_map_set
86+
raise ValueError(f"Active atom types {missing} not in provided type_map.")
87+
88+
# for the condition that type_map is a proper superset of atom_names,
89+
# we allow new elements with atom_numb = 0
90+
new_names = list(type_map)
91+
new_numbs = []
92+
name_to_old_idx = {name: i for i, name in enumerate(orig_names)}
93+
94+
for name in new_names:
95+
if name in name_to_old_idx:
96+
new_numbs.append(orig_numbs[name_to_old_idx[name]])
97+
else:
98+
new_numbs.append(0)
99+
100+
# build mapping from old atom type index to new one
101+
# old_types[i] = j --> new_types[i] = type_map.index(atom_names[j])
102+
old_to_new_index = {}
103+
for old_idx, name in enumerate(orig_names):
104+
if name in type_map_set:
105+
new_idx = type_map.index(name)
106+
old_to_new_index[old_idx] = new_idx
107+
108+
# remap atom_types using the index mapping
109+
old_types = np.array(data["atom_types"])
110+
new_types = np.empty_like(old_types)
111+
for old_idx, new_idx in old_to_new_index.items():
112+
new_types[old_types == old_idx] = new_idx
113+
114+
# update data in-place
115+
data["atom_names"] = new_names
116+
data["atom_numbs"] = new_numbs
117+
data["atom_types"] = new_types
118+
92119
else:
93120
# index that will sort an array by alphabetical order
121+
# idx = argsort(atom_names) --> atom_names[idx] is sorted
94122
idx = np.argsort(data["atom_names"], kind="stable")
95-
# sort atom_names, atom_numbs, atom_types by idx
96-
data["atom_names"] = list(np.array(data["atom_names"])[idx])
97-
data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
98-
data["atom_types"] = np.argsort(idx, kind="stable")[data["atom_types"]]
123+
# sort atom_names and atom_numbs by idx
124+
data["atom_names"] = list(np.array(data["atom_names"])[idx])
125+
data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
126+
# to update atom_types: we need the inverse permutation of idx
127+
# because if old_type = i, and atom_names[i] moves to position j,
128+
# then the new type should be j.
129+
# inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i
130+
inv_idx = np.argsort(idx, kind="stable")
131+
data["atom_types"] = inv_idx[data["atom_types"]]
132+
99133
return data
100134

101135

0 commit comments

Comments
 (0)