@@ -62,9 +62,10 @@ def add_atom_names(data, atom_names):
6262
6363
6464def sort_atom_names (data , type_map = None ):
65- """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
65+ """Sort atom_names of the system and reorder atom_numbs and atom_types according
6666 to atom_names. If type_map is not given, atom_names will be sorted by
67- alphabetical order. If type_map is given, atom_names will be type_map.
67+ alphabetical order. If type_map is given, atom_names will be set to type_map,
68+ and zero-count elements are kept.
6869
6970 Parameters
7071 ----------
@@ -74,28 +75,61 @@ def sort_atom_names(data, type_map=None):
7475 type_map
7576 """
7677 if type_map is not None :
77- # assign atom_names index to the specify order
78- # atom_names must be a subset of type_map
79- assert set (data ["atom_names" ]).issubset (set (type_map ))
80- # for the condition that type_map is a proper superset of atom_names
81- # new_atoms = set(type_map) - set(data["atom_names"])
82- new_atoms = [e for e in type_map if e not in data ["atom_names" ]]
83- if new_atoms :
84- data = add_atom_names (data , new_atoms )
85- # index that will sort an array by type_map
86- # a[as[a]] == b[as[b]] as == argsort
87- # as[as[b]] == as^{-1}[b]
88- # a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
89- idx = np .argsort (data ["atom_names" ], kind = "stable" )[
90- np .argsort (np .argsort (type_map , kind = "stable" ), kind = "stable" )
91- ]
78+ # assign atom_names index to the specified order
79+ # only active (numb > 0) atom names must be in type_map
80+ orig_names = data ["atom_names" ]
81+ orig_numbs = data ["atom_numbs" ]
82+ active_names = {name for name , numb in zip (orig_names , orig_numbs ) if numb > 0 }
83+ type_map_set = set (type_map )
84+ if not active_names .issubset (type_map_set ):
85+ missing = active_names - type_map_set
86+ raise ValueError (f"Active atom types { missing } not in provided type_map." )
87+
88+ # for the condition that type_map is a proper superset of atom_names,
89+ # we allow new elements with atom_numb = 0
90+ new_names = list (type_map )
91+ new_numbs = []
92+ name_to_old_idx = {name : i for i , name in enumerate (orig_names )}
93+
94+ for name in new_names :
95+ if name in name_to_old_idx :
96+ new_numbs .append (orig_numbs [name_to_old_idx [name ]])
97+ else :
98+ new_numbs .append (0 )
99+
100+ # build mapping from old atom type index to new one
101+ # old_types[i] = j --> new_types[i] = type_map.index(atom_names[j])
102+ old_to_new_index = {}
103+ for old_idx , name in enumerate (orig_names ):
104+ if name in type_map_set :
105+ new_idx = type_map .index (name )
106+ old_to_new_index [old_idx ] = new_idx
107+
108+ # remap atom_types using the index mapping
109+ old_types = np .array (data ["atom_types" ])
110+ new_types = np .empty_like (old_types )
111+ for old_idx , new_idx in old_to_new_index .items ():
112+ new_types [old_types == old_idx ] = new_idx
113+
114+ # update data in-place
115+ data ["atom_names" ] = new_names
116+ data ["atom_numbs" ] = new_numbs
117+ data ["atom_types" ] = new_types
118+
92119 else :
93120 # index that will sort an array by alphabetical order
121+ # idx = argsort(atom_names) --> atom_names[idx] is sorted
94122 idx = np .argsort (data ["atom_names" ], kind = "stable" )
95- # sort atom_names, atom_numbs, atom_types by idx
96- data ["atom_names" ] = list (np .array (data ["atom_names" ])[idx ])
97- data ["atom_numbs" ] = list (np .array (data ["atom_numbs" ])[idx ])
98- data ["atom_types" ] = np .argsort (idx , kind = "stable" )[data ["atom_types" ]]
123+ # sort atom_names and atom_numbs by idx
124+ data ["atom_names" ] = list (np .array (data ["atom_names" ])[idx ])
125+ data ["atom_numbs" ] = list (np .array (data ["atom_numbs" ])[idx ])
126+ # to update atom_types: we need the inverse permutation of idx
127+ # because if old_type = i, and atom_names[i] moves to position j,
128+ # then the new type should be j.
129+ # inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i
130+ inv_idx = np .argsort (idx , kind = "stable" )
131+ data ["atom_types" ] = inv_idx [data ["atom_types" ]]
132+
99133 return data
100134
101135
0 commit comments