Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def add_atom_names(data, atom_names):


def sort_atom_names(data, type_map=None):
"""Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
"""Sort atom_names of the system and reorder atom_numbs and atom_types according
to atom_names. If type_map is not given, atom_names will be sorted by
alphabetical order. If type_map is given, atom_names will be type_map.
alphabetical order. If type_map is given, atom_names will be set to type_map,
and zero-count elements are kept.

Parameters
----------
Expand All @@ -74,28 +75,61 @@ def sort_atom_names(data, type_map=None):
type_map
"""
if type_map is not None:
# assign atom_names index to the specify order
# atom_names must be a subset of type_map
assert set(data["atom_names"]).issubset(set(type_map))
# for the condition that type_map is a proper superset of atom_names
# new_atoms = set(type_map) - set(data["atom_names"])
new_atoms = [e for e in type_map if e not in data["atom_names"]]
if new_atoms:
data = add_atom_names(data, new_atoms)
# index that will sort an array by type_map
# a[as[a]] == b[as[b]] as == argsort
# as[as[b]] == as^{-1}[b]
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
idx = np.argsort(data["atom_names"], kind="stable")[
np.argsort(np.argsort(type_map, kind="stable"), kind="stable")
]
# assign atom_names index to the specified order
# only active (numb > 0) atom names must be in type_map
orig_names = data["atom_names"]
orig_numbs = data["atom_numbs"]
active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0}
type_map_set = set(type_map)
if not active_names.issubset(type_map_set):
missing = active_names - type_map_set
raise ValueError(f"Active atom types {missing} not in provided type_map.")

# for the condition that type_map is a proper superset of atom_names,
# we allow new elements with atom_numb = 0
new_names = list(type_map)
new_numbs = []
name_to_old_idx = {name: i for i, name in enumerate(orig_names)}

for name in new_names:
if name in name_to_old_idx:
new_numbs.append(orig_numbs[name_to_old_idx[name]])
else:
new_numbs.append(0)

# build mapping from old atom type index to new one
# old_types[i] = j --> new_types[i] = type_map.index(atom_names[j])
old_to_new_index = {}
for old_idx, name in enumerate(orig_names):
if name in type_map_set:
new_idx = type_map.index(name)
old_to_new_index[old_idx] = new_idx

# remap atom_types using the index mapping
old_types = np.array(data["atom_types"])
new_types = np.empty_like(old_types)
for old_idx, new_idx in old_to_new_index.items():
new_types[old_types == old_idx] = new_idx

# update data in-place
data["atom_names"] = new_names
data["atom_numbs"] = new_numbs
data["atom_types"] = new_types

else:
# index that will sort an array by alphabetical order
# idx = argsort(atom_names) --> atom_names[idx] is sorted
idx = np.argsort(data["atom_names"], kind="stable")
# sort atom_names, atom_numbs, atom_types by idx
data["atom_names"] = list(np.array(data["atom_names"])[idx])
data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
data["atom_types"] = np.argsort(idx, kind="stable")[data["atom_types"]]
# sort atom_names and atom_numbs by idx
data["atom_names"] = list(np.array(data["atom_names"])[idx])
data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
# to update atom_types: we need the inverse permutation of idx
# because if old_type = i, and atom_names[i] moves to position j,
# then the new type should be j.
# inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i
inv_idx = np.argsort(idx, kind="stable")
data["atom_types"] = inv_idx[data["atom_types"]]

return data


Expand Down
86 changes: 86 additions & 0 deletions tests/test_type_map_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import unittest

import numpy as np

from dpdata.utils import sort_atom_names


class TestSortAtomNames(unittest.TestCase):
def test_sort_atom_names_type_map(self):
# Test basic functionality with type_map
data = {
"atom_names": ["H", "O"],
"atom_numbs": [2, 1],
"atom_types": np.array([1, 0, 0]),
}
type_map = ["O", "H"]
result = sort_atom_names(data, type_map=type_map)

self.assertEqual(result["atom_names"], ["O", "H"])
self.assertEqual(result["atom_numbs"], [1, 2])
np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

def test_sort_atom_names_type_map_with_zero_atoms(self):
# Test with type_map that includes elements with zero atoms
data = {
"atom_names": ["H", "O"],
"atom_numbs": [2, 1],
"atom_types": np.array([1, 0, 0]),
}
type_map = ["O", "H", "C"] # C is not in atom_names but in type_map
result = sort_atom_names(data, type_map=type_map)

self.assertEqual(result["atom_names"], ["O", "H", "C"])
self.assertEqual(result["atom_numbs"], [1, 2, 0])
np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))

def test_sort_atom_names_type_map_missing_active_types(self):
# Test that ValueError is raised when active atom types are missing from type_map
data = {
"atom_names": ["H", "O"],
"atom_numbs": [2, 1], # Both H and O are active (numb > 0)
"atom_types": np.array([1, 0, 0]),
}
type_map = ["H"] # O is active but missing from type_map

with self.assertRaises(ValueError) as cm:
sort_atom_names(data, type_map=type_map)

self.assertIn("Active atom types", str(cm.exception))
self.assertIn("not in provided type_map", str(cm.exception))
self.assertIn("O", str(cm.exception))

def test_sort_atom_names_without_type_map(self):
# Test sorting without type_map (alphabetical order)
data = {
"atom_names": ["Zn", "O", "H"],
"atom_numbs": [1, 1, 2],
"atom_types": np.array([0, 1, 2, 2]),
}
result = sort_atom_names(data)

self.assertEqual(result["atom_names"], ["H", "O", "Zn"])
self.assertEqual(result["atom_numbs"], [2, 1, 1])
np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0]))

def test_sort_atom_names_with_zero_count_elements_removed(self):
# Test the case where original elements are A B C, but counts are 0 1 2,
# which should be able to map to B C (removing A which has count 0)
# Example: A, B, C = Cl, O, C
data = {
"atom_names": ["Cl", "O", "C"],
"atom_numbs": [0, 1, 2],
"atom_types": np.array([1, 2, 2]),
}
type_map = ["O", "C"] # Cl is omitted because it has 0 atoms
result = sort_atom_names(data, type_map=type_map)

self.assertEqual(result["atom_names"], ["O", "C"])
self.assertEqual(result["atom_numbs"], [1, 2])
np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))


if __name__ == "__main__":
unittest.main()