diff --git a/dpdata/utils.py b/dpdata/utils.py index 58a908cc..4c72bec9 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -62,9 +62,10 @@ def add_atom_names(data, atom_names): def sort_atom_names(data, type_map=None): - """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding + """Sort atom_names of the system and reorder atom_numbs and atom_types according to atom_names. If type_map is not given, atom_names will be sorted by - alphabetical order. If type_map is given, atom_names will be type_map. + alphabetical order. If type_map is given, atom_names will be set to type_map, + and zero-count elements are kept. Parameters ---------- @@ -74,28 +75,61 @@ def sort_atom_names(data, type_map=None): type_map """ if type_map is not None: - # assign atom_names index to the specify order - # atom_names must be a subset of type_map - assert set(data["atom_names"]).issubset(set(type_map)) - # for the condition that type_map is a proper superset of atom_names - # new_atoms = set(type_map) - set(data["atom_names"]) - new_atoms = [e for e in type_map if e not in data["atom_names"]] - if new_atoms: - data = add_atom_names(data, new_atoms) - # index that will sort an array by type_map - # a[as[a]] == b[as[b]] as == argsort - # as[as[b]] == as^{-1}[b] - # a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id] - idx = np.argsort(data["atom_names"], kind="stable")[ - np.argsort(np.argsort(type_map, kind="stable"), kind="stable") - ] + # assign atom_names index to the specified order + # only active (numb > 0) atom names must be in type_map + orig_names = data["atom_names"] + orig_numbs = data["atom_numbs"] + active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0} + type_map_set = set(type_map) + if not active_names.issubset(type_map_set): + missing = active_names - type_map_set + raise ValueError(f"Active atom types {missing} not in provided type_map.") + + # for the condition that type_map is a proper superset of atom_names, + # we allow new elements with atom_numb = 0 + new_names = list(type_map) + new_numbs = [] + name_to_old_idx = {name: i for i, name in enumerate(orig_names)} + + for name in new_names: + if name in name_to_old_idx: + new_numbs.append(orig_numbs[name_to_old_idx[name]]) + else: + new_numbs.append(0) + + # build mapping from old atom type index to new one + # old_types[i] = j --> new_types[i] = type_map.index(atom_names[j]) + old_to_new_index = {} + for old_idx, name in enumerate(orig_names): + if name in type_map_set: + new_idx = type_map.index(name) + old_to_new_index[old_idx] = new_idx + + # remap atom_types using the index mapping + old_types = np.array(data["atom_types"]) + new_types = np.empty_like(old_types) + for old_idx, new_idx in old_to_new_index.items(): + new_types[old_types == old_idx] = new_idx + + # update data in-place + data["atom_names"] = new_names + data["atom_numbs"] = new_numbs + data["atom_types"] = new_types + else: # index that will sort an array by alphabetical order + # idx = argsort(atom_names) --> atom_names[idx] is sorted idx = np.argsort(data["atom_names"], kind="stable") - # sort atom_names, atom_numbs, atom_types by idx - data["atom_names"] = list(np.array(data["atom_names"])[idx]) - data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx]) - data["atom_types"] = np.argsort(idx, kind="stable")[data["atom_types"]] + # sort atom_names and atom_numbs by idx + data["atom_names"] = list(np.array(data["atom_names"])[idx]) + data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx]) + # to update atom_types: we need the inverse permutation of idx + # because if old_type = i, and atom_names[i] moves to position j, + # then the new type should be j. + # inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i + inv_idx = np.argsort(idx, kind="stable") + data["atom_types"] = inv_idx[data["atom_types"]] + return data diff --git a/tests/test_type_map_utils.py b/tests/test_type_map_utils.py new file mode 100644 index 00000000..4ae9567d --- /dev/null +++ b/tests/test_type_map_utils.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from dpdata.utils import sort_atom_names + + +class TestSortAtomNames(unittest.TestCase): + def test_sort_atom_names_type_map(self): + # Test basic functionality with type_map + data = { + "atom_names": ["H", "O"], + "atom_numbs": [2, 1], + "atom_types": np.array([1, 0, 0]), + } + type_map = ["O", "H"] + result = sort_atom_names(data, type_map=type_map) + + self.assertEqual(result["atom_names"], ["O", "H"]) + self.assertEqual(result["atom_numbs"], [1, 2]) + np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) + + def test_sort_atom_names_type_map_with_zero_atoms(self): + # Test with type_map that includes elements with zero atoms + data = { + "atom_names": ["H", "O"], + "atom_numbs": [2, 1], + "atom_types": np.array([1, 0, 0]), + } + type_map = ["O", "H", "C"] # C is not in atom_names but in type_map + result = sort_atom_names(data, type_map=type_map) + + self.assertEqual(result["atom_names"], ["O", "H", "C"]) + self.assertEqual(result["atom_numbs"], [1, 2, 0]) + np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) + + def test_sort_atom_names_type_map_missing_active_types(self): + # Test that ValueError is raised when active atom types are missing from type_map + data = { + "atom_names": ["H", "O"], + "atom_numbs": [2, 1], # Both H and O are active (numb > 0) + "atom_types": np.array([1, 0, 0]), + } + type_map = ["H"] # O is active but missing from type_map + + with self.assertRaises(ValueError) as cm: + sort_atom_names(data, type_map=type_map) + + self.assertIn("Active atom types", str(cm.exception)) + self.assertIn("not in provided type_map", str(cm.exception)) + self.assertIn("O", str(cm.exception)) + + def test_sort_atom_names_without_type_map(self): + # Test sorting without type_map (alphabetical order) + data = { + "atom_names": ["Zn", "O", "H"], + "atom_numbs": [1, 1, 2], + "atom_types": np.array([0, 1, 2, 2]), + } + result = sort_atom_names(data) + + self.assertEqual(result["atom_names"], ["H", "O", "Zn"]) + self.assertEqual(result["atom_numbs"], [2, 1, 1]) + np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0])) + + def test_sort_atom_names_with_zero_count_elements_removed(self): + # Test the case where original elements are A B C, but counts are 0 1 2, + # which should be able to map to B C (removing A which has count 0) + # Example: A, B, C = Cl, O, C + data = { + "atom_names": ["Cl", "O", "C"], + "atom_numbs": [0, 1, 2], + "atom_types": np.array([1, 2, 2]), + } + type_map = ["O", "C"] # Cl is omitted because it has 0 atoms + result = sort_atom_names(data, type_map=type_map) + + self.assertEqual(result["atom_names"], ["O", "C"]) + self.assertEqual(result["atom_numbs"], [1, 2]) + np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) + + +if __name__ == "__main__": + unittest.main()