33import unittest
44
55import numpy as np
6- from context import dpdata
76
87from dpdata .utils import sort_atom_names
98
@@ -18,11 +17,11 @@ def test_sort_atom_names_type_map(self):
1817 }
1918 type_map = ["O" , "H" ]
2019 result = sort_atom_names (data , type_map = type_map )
21-
20+
2221 self .assertEqual (result ["atom_names" ], ["O" , "H" ])
2322 self .assertEqual (result ["atom_numbs" ], [1 , 2 ])
2423 np .testing .assert_array_equal (result ["atom_types" ], np .array ([0 , 1 , 1 ]))
25-
24+
2625 def test_sort_atom_names_type_map_with_zero_atoms (self ):
2726 # Test with type_map that includes elements with zero atoms
2827 data = {
@@ -32,11 +31,11 @@ def test_sort_atom_names_type_map_with_zero_atoms(self):
3231 }
3332 type_map = ["O" , "H" , "C" ] # C is not in atom_names but in type_map
3433 result = sort_atom_names (data , type_map = type_map )
35-
34+
3635 self .assertEqual (result ["atom_names" ], ["O" , "H" , "C" ])
3736 self .assertEqual (result ["atom_numbs" ], [1 , 2 , 0 ])
3837 np .testing .assert_array_equal (result ["atom_types" ], np .array ([0 , 1 , 1 ]))
39-
38+
4039 def test_sort_atom_names_type_map_missing_active_types (self ):
4140 # Test that ValueError is raised when active atom types are missing from type_map
4241 data = {
@@ -45,14 +44,14 @@ def test_sort_atom_names_type_map_missing_active_types(self):
4544 "atom_types" : np .array ([1 , 0 , 0 ]),
4645 }
4746 type_map = ["H" ] # O is active but missing from type_map
48-
47+
4948 with self .assertRaises (ValueError ) as cm :
5049 sort_atom_names (data , type_map = type_map )
51-
50+
5251 self .assertIn ("Active atom types" , str (cm .exception ))
5352 self .assertIn ("not in provided type_map" , str (cm .exception ))
5453 self .assertIn ("O" , str (cm .exception ))
55-
54+
5655 def test_sort_atom_names_without_type_map (self ):
5756 # Test sorting without type_map (alphabetical order)
5857 data = {
@@ -61,7 +60,7 @@ def test_sort_atom_names_without_type_map(self):
6160 "atom_types" : np .array ([0 , 1 , 2 , 2 ]),
6261 }
6362 result = sort_atom_names (data )
64-
63+
6564 self .assertEqual (result ["atom_names" ], ["H" , "O" , "Zn" ])
6665 self .assertEqual (result ["atom_numbs" ], [2 , 1 , 1 ])
6766 np .testing .assert_array_equal (result ["atom_types" ], np .array ([2 , 1 , 0 , 0 ]))
@@ -76,11 +75,11 @@ def test_sort_atom_names_with_zero_count_elements_removed(self):
7675 }
7776 type_map = ["O" , "C" ] # A is omitted because it has 0 atoms
7877 result = sort_atom_names (data , type_map = type_map )
79-
78+
8079 self .assertEqual (result ["atom_names" ], ["O" , "C" ])
8180 self .assertEqual (result ["atom_numbs" ], [1 , 2 ])
8281 np .testing .assert_array_equal (result ["atom_types" ], np .array ([0 , 1 , 1 ]))
8382
8483
8584if __name__ == "__main__" :
86- unittest .main ()
85+ unittest .main ()
0 commit comments