diff --git a/pyproject.toml b/pyproject.toml index 4801a5537f6..d3272d1b437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ dependencies = [ "requests>=2.32", "ruamel.yaml>=0.17.0", "scipy>=1.13.0", - "spglib>=2.0.2", + "spglib>=2.5.0", "sympy>=1.2", "tabulate>=0.9", "tqdm>=4.60", diff --git a/src/pymatgen/symmetry/analyzer.py b/src/pymatgen/symmetry/analyzer.py index 17b29e3e69e..f404aa2170b 100644 --- a/src/pymatgen/symmetry/analyzer.py +++ b/src/pymatgen/symmetry/analyzer.py @@ -35,9 +35,12 @@ if TYPE_CHECKING: from typing import Any, Literal + from numpy.typing import NDArray from pymatgen.core import Element, Species from pymatgen.core.sites import Site from pymatgen.symmetry.groups import CrystalSystem + from pymatgen.util.typing import Kpoint + from spglib import SpglibDataset LatticeType = Literal["cubic", "hexagonal", "monoclinic", "orthorhombic", "rhombohedral", "tetragonal", "triclinic"] @@ -74,7 +77,12 @@ class SpacegroupAnalyzer: Uses spglib to perform various symmetry finding operations. """ - def __init__(self, structure: Structure, symprec: float | None = 0.01, angle_tolerance: float = 5) -> None: + def __init__( + self, + structure: Structure, + symprec: float | None = 0.01, + angle_tolerance: float = 5, + ) -> None: """ Args: structure (Structure/IStructure): Structure to find symmetry @@ -93,7 +101,6 @@ def __init__(self, structure: Structure, symprec: float | None = 0.01, angle_tol self._site_props = structure.site_properties unique_species: list[Element | Species] = [] zs = [] - magmoms = [] for species, group in itertools.groupby(structure, key=lambda s: s.species): if species in unique_species: ind = unique_species.index(species) @@ -106,6 +113,7 @@ def __init__(self, structure: Structure, symprec: float | None = 0.01, angle_tol getattr(specie, "spin", None) is not None for specie in structure.types_of_species ) + magmoms = [] for site in structure: if hasattr(site, "magmom"): magmoms.append(site.magmom) @@ -139,7 +147,7 @@ def get_space_group_symbol(self) -> str: Returns: str: Spacegroup symbol for structure. """ - return self._space_group_data["international"] + return self._space_group_data.international def get_space_group_number(self) -> int: """Get the international spacegroup number (e.g., 62) for structure. @@ -147,7 +155,7 @@ def get_space_group_number(self) -> int: Returns: int: International spacegroup number for structure. """ - return int(self._space_group_data["number"]) + return int(self._space_group_data.number) def get_space_group_operations(self) -> SpacegroupOperations: """Get the SpacegroupOperations for the Structure. @@ -167,7 +175,7 @@ def get_hall(self) -> str: Returns: str: Hall symbol """ - return self._space_group_data["hall"] + return self._space_group_data.hall def get_point_group_symbol(self) -> str: """Get the point group associated with the structure. @@ -175,7 +183,7 @@ def get_point_group_symbol(self) -> str: Returns: Pointgroup: Point group for structure. """ - rotations = self._space_group_data["rotations"] + rotations = self._space_group_data.rotations # passing a 0-length rotations list to spglib can segfault if len(rotations) == 0: return "1" @@ -191,10 +199,10 @@ def get_crystal_system(self) -> CrystalSystem: Returns: str: Crystal system for structure """ - n = self._space_group_data["number"] + n = self._space_group_data.number - # not using isinstance(n, int) to allow 0-decimal floats - if not (n == int(n) and 0 < n < 231): + # Not using isinstance(n, int) to allow 0-decimal floats + if n != int(n) or not 0 < n < 231: raise ValueError(f"Received invalid space group {n}") if 0 < n < 3: @@ -222,33 +230,31 @@ def get_lattice_type(self) -> LatticeType: Returns: str: Lattice type for structure """ - spg_num = self._space_group_data["number"] + spg_num = self._space_group_data.number system = self.get_crystal_system() - if spg_num in (146, 148, 155, 160, 161, 166, 167): + if spg_num in {146, 148, 155, 160, 161, 166, 167}: return "rhombohedral" - if system == "trigonal": - return "hexagonal" - return system + return "hexagonal" if system == "trigonal" else system - def get_symmetry_dataset(self): - """Get the symmetry dataset as a dict. + def get_symmetry_dataset(self) -> SpglibDataset: + """Get the symmetry dataset as a SpglibDataset. Returns: - dict: With the following properties: + frozen dict: With the following properties: number: International space group number international: International symbol hall: Hall symbol transformation_matrix: Transformation matrix from lattice of - input cell to Bravais lattice L^bravais = L^original * Tmat - origin shift: Origin shift in the setting of "Bravais lattice" - rotations, translations: Rotation matrices and translation - vectors. Space group operations are obtained by - [(r,t) for r, t in zip(rotations, translations)] + input cell to Bravais lattice L^bravais = L^original * Tmat + origin shift: Origin shift in the setting of "Bravais lattice" + rotations, translations: Rotation matrices and translation + vectors. Space group operations are obtained by + [(r,t) for r, t in zip(rotations, translations)] wyckoffs: Wyckoff letters """ return self._space_group_data - def _get_symmetry(self): + def _get_symmetry(self) -> tuple[NDArray, NDArray]: """Get the symmetry operations associated with the structure. Returns: @@ -269,16 +275,16 @@ def _get_symmetry(self): # [1e-4, 2e-4, 1e-4] # (these are in fractional coordinates, so should be small denominator # fractions) - translations = [] - for t in dct["translations"]: - translations.append([float(Fraction(c).limit_denominator(1000)) for c in t]) - translations = np.array(translations) + _translations: list = [] + for trans in dct["translations"]: + _translations.append([float(Fraction(c).limit_denominator(1000)) for c in trans]) + translations: NDArray = np.array(_translations) - # fractional translations of 1 are more simply 0 + # Fractional translations of 1 are more simply 0 translations[np.abs(translations) == 1] = 0 return dct["rotations"], translations - def get_symmetry_operations(self, cartesian=False): + def get_symmetry_operations(self, cartesian: bool = False) -> list[SymmOp]: """Return symmetry operations as a list of SymmOp objects. By default returns fractional coord sym_ops. But Cartesian can be returned too. @@ -297,7 +303,7 @@ def get_symmetry_operations(self, cartesian=False): sym_ops.append(op) return sym_ops - def get_point_group_operations(self, cartesian=False): + def get_point_group_operations(self, cartesian: bool = False) -> list[SymmOp]: """Return symmetry operations as a list of SymmOp objects. By default returns fractional coord symm ops. But Cartesian can be returned too. @@ -324,7 +330,7 @@ def get_point_group_operations(self, cartesian=False): symm_ops.append(op) return symm_ops - def get_symmetrized_structure(self): + def get_symmetrized_structure(self) -> SymmetrizedStructure: """Get a symmetrized structure. A symmetrized structure is one where the sites have been grouped into symmetrically equivalent groups. @@ -337,9 +343,9 @@ def get_symmetrized_structure(self): self.get_space_group_number(), self.get_symmetry_operations(), ) - return SymmetrizedStructure(self._structure, spg_ops, sym_dataset["equivalent_atoms"], sym_dataset["wyckoffs"]) + return SymmetrizedStructure(self._structure, spg_ops, sym_dataset.equivalent_atoms, sym_dataset.wyckoffs) - def get_refined_structure(self, keep_site_properties=False): + def get_refined_structure(self, keep_site_properties: bool = False) -> Structure: """Get the refined structure based on detected symmetry. The refined structure is a *conventional* cell setting with atoms moved to the expected symmetry positions. @@ -368,7 +374,7 @@ def get_refined_structure(self, keep_site_properties=False): struct = Structure(lattice, species, scaled_positions, site_properties=site_properties) return struct.get_sorted_structure() - def find_primitive(self, keep_site_properties=False): + def find_primitive(self, keep_site_properties: bool = False) -> Structure: """Find a primitive version of the unit cell. Args: @@ -390,8 +396,8 @@ def find_primitive(self, keep_site_properties=False): species = [self._unique_species[i - 1] for i in numbers] if keep_site_properties: site_properties = {} - for k, v in self._site_props.items(): - site_properties[k] = [v[i - 1] for i in numbers] + for key, val in self._site_props.items(): + site_properties[key] = [val[i - 1] for i in numbers] else: site_properties = None @@ -399,9 +405,13 @@ def find_primitive(self, keep_site_properties=False): lattice, species, scaled_positions, to_unit_cell=True, site_properties=site_properties ).get_reduced_structure() - def get_ir_reciprocal_mesh(self, mesh=(10, 10, 10), is_shift=(0, 0, 0)): - """k-point mesh of the Brillouin zone generated taken into account symmetry.The - method returns the irreducible kpoints of the mesh and their weights. + def get_ir_reciprocal_mesh( + self, + mesh: tuple[int, int, int] = (10, 10, 10), + is_shift: tuple[float, float, float] = (0, 0, 0), + ) -> list[tuple[Kpoint, float]]: + """k-point mesh of the Brillouin zone generated taken into account symmetry. + The method returns the irreducible kpoints of the mesh and their weights. Args: mesh (3x1 array): The number of kpoint for the mesh needed in @@ -418,11 +428,15 @@ def get_ir_reciprocal_mesh(self, mesh=(10, 10, 10), is_shift=(0, 0, 0)): mapping, grid = spglib.get_ir_reciprocal_mesh(np.array(mesh), self._cell, is_shift=shift, symprec=self._symprec) results = [] - for i, count in zip(*np.unique(mapping, return_counts=True)): - results.append(((grid[i] + shift * (0.5, 0.5, 0.5)) / mesh, count)) + for idx, count in zip(*np.unique(mapping, return_counts=True)): + results.append(((grid[idx] + shift * (0.5, 0.5, 0.5)) / mesh, count)) return results - def get_ir_reciprocal_mesh_map(self, mesh=(10, 10, 10), is_shift=(0, 0, 0)): + def get_ir_reciprocal_mesh_map( + self, + mesh: tuple[int, int, int] = (10, 10, 10), + is_shift: tuple[float, float, float] = (0, 0, 0), + ) -> tuple[NDArray, NDArray]: """Same as 'get_ir_reciprocal_mesh' but the full grid together with the mapping that maps a reducible to an irreducible kpoint is returned. @@ -445,7 +459,10 @@ def get_ir_reciprocal_mesh_map(self, mesh=(10, 10, 10), is_shift=(0, 0, 0)): return grid_fractional_coords, mapping @cite_conventional_cell_algo - def get_conventional_to_primitive_transformation_matrix(self, international_monoclinic=True): + def get_conventional_to_primitive_transformation_matrix( + self, + international_monoclinic: bool = True, + ) -> NDArray: """Get the transformation matrix to transform a conventional unit cell to a primitive cell according to certain standards the standards are defined in Setyawan, W., & Curtarolo, S. (2010). High-throughput electronic band structure @@ -466,30 +483,32 @@ def get_conventional_to_primitive_transformation_matrix(self, international_mono return np.eye(3) if lattice == "rhombohedral": - # check if the conventional representation is hexagonal or + # Check if the conventional representation is hexagonal or # rhombohedral lengths = conv.lattice.lengths if abs(lengths[0] - lengths[2]) < 0.0001: - transf = np.eye - else: - transf = np.array([[-1, 1, 1], [2, 1, 1], [-1, -2, 1]], dtype=np.float64) / 3 + return np.eye + return np.array([[-1, 1, 1], [2, 1, 1], [-1, -2, 1]], dtype=np.float64) / 3 + + if "I" in self.get_space_group_symbol(): + return np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]], dtype=np.float64) / 2 - elif "I" in self.get_space_group_symbol(): - transf = np.array([[-1, 1, 1], [1, -1, 1], [1, 1, -1]], dtype=np.float64) / 2 - elif "F" in self.get_space_group_symbol(): - transf = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=np.float64) / 2 - elif "C" in self.get_space_group_symbol() or "A" in self.get_space_group_symbol(): + if "F" in self.get_space_group_symbol(): + return np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=np.float64) / 2 + + if "C" in self.get_space_group_symbol() or "A" in self.get_space_group_symbol(): if self.get_crystal_system() == "monoclinic": - transf = np.array([[1, 1, 0], [-1, 1, 0], [0, 0, 2]], dtype=np.float64) / 2 - else: - transf = np.array([[1, -1, 0], [1, 1, 0], [0, 0, 2]], dtype=np.float64) / 2 - else: - transf = np.eye(3) + return np.array([[1, 1, 0], [-1, 1, 0], [0, 0, 2]], dtype=np.float64) / 2 + return np.array([[1, -1, 0], [1, 1, 0], [0, 0, 2]], dtype=np.float64) / 2 - return transf + return np.eye(3) @cite_conventional_cell_algo - def get_primitive_standard_structure(self, international_monoclinic=True, keep_site_properties=False): + def get_primitive_standard_structure( + self, + international_monoclinic: bool = True, + keep_site_properties: bool = False, + ) -> Structure: """Get a structure with a primitive cell according to certain standards. The standards are defined in Setyawan, W., & Curtarolo, S. (2010). High-throughput electronic band structure calculations: Challenges and tools. Computational @@ -522,7 +541,7 @@ def get_primitive_standard_structure(self, international_monoclinic=True, keep_s international_monoclinic=international_monoclinic ) - new_sites = [] + new_sites: list[PeriodicSite] = [] lattice = Lattice(np.dot(transf, conv.lattice.matrix)) for site in conv: new_s = PeriodicSite( @@ -568,7 +587,11 @@ def get_primitive_standard_structure(self, international_monoclinic=True, keep_s return Structure.from_sites(new_sites) @cite_conventional_cell_algo - def get_conventional_standard_structure(self, international_monoclinic=True, keep_site_properties=False): + def get_conventional_standard_structure( + self, + international_monoclinic: bool = True, + keep_site_properties: bool = False, + ) -> Structure: """Get a structure with a conventional cell according to certain standards. The standards are defined in Setyawan, W., & Curtarolo, S. (2010). High-throughput electronic band structure calculations: Challenges and tools. Computational @@ -603,7 +626,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee key=lambda k: k["length"], ) - if latt_type in ("orthorhombic", "cubic"): + if latt_type in {"orthorhombic", "cubic"}: # you want to keep the c axis where it is # to keep the C- settings transf = np.zeros(shape=(3, 3)) @@ -611,7 +634,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee transf[2] = [0, 0, 1] a, b = sorted(lattice.abc[:2]) sorted_dic = sorted( - ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in [0, 1]), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in (0, 1)), key=lambda k: k["length"], ) for idx in range(2): @@ -623,7 +646,7 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee transf[2] = [1, 0, 0] a, b = sorted(lattice.abc[1:]) sorted_dic = sorted( - ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in [1, 2]), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in (1, 2)), key=lambda k: k["length"], ) for idx in range(2): @@ -647,7 +670,8 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee a, c = c, a transf = np.dot([[0, 0, 1], [0, 1, 0], [1, 0, 0]], transf) lattice = Lattice.tetragonal(a, c) - elif latt_type in ("hexagonal", "rhombohedral"): + + elif latt_type in {"hexagonal", "rhombohedral"}: # for the conventional cell representation, # we always show the rhombohedral lattices as hexagonal @@ -676,25 +700,25 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee transf = np.zeros(shape=(3, 3)) transf[2] = [0, 0, 1] sorted_dic = sorted( - ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in [0, 1]), + ({"vec": lattice.matrix[i], "length": lattice.abc[i], "orig_index": i} for i in (0, 1)), key=lambda k: k["length"], ) a = sorted_dic[0]["length"] b = sorted_dic[1]["length"] c = lattice.abc[2] new_matrix = None - for t in itertools.permutations(list(range(2)), 2): + for tp2 in itertools.permutations(list(range(2)), 2): m = lattice.matrix - latt2 = Lattice([m[t[0]], m[t[1]], m[2]]) + latt2 = Lattice([m[tp2[0]], m[tp2[1]], m[2]]) lengths = latt2.lengths angles = latt2.angles if angles[0] > 90: # if the angle is > 90 we invert a and b to get # an angle < 90 - a, b, c, alpha, beta, gamma = Lattice([-m[t[0]], -m[t[1]], m[2]]).parameters + a, b, c, alpha, beta, gamma = Lattice([-m[tp2[0]], -m[tp2[1]], m[2]]).parameters transf = np.zeros(shape=(3, 3)) - transf[0][t[0]] = -1 - transf[1][t[1]] = -1 + transf[0][tp2[0]] = -1 + transf[1][tp2[1]] = -1 transf[2][2] = 1 alpha = math.pi * alpha / 180 new_matrix = [ @@ -706,8 +730,8 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee if angles[0] < 90: transf = np.zeros(shape=(3, 3)) - transf[0][t[0]] = 1 - transf[1][t[1]] = 1 + transf[0][tp2[0]] = 1 + transf[1][tp2[1]] = 1 transf[2][2] = 1 a, b, c = lengths alpha = math.pi * angles[0] / 180 @@ -732,15 +756,15 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee # and b 90 and b < c: - a, b, c, alpha, beta, gamma = Lattice([-m[t[0]], -m[t[1]], m[t[2]]]).parameters + a, b, c, alpha, beta, gamma = Lattice([-m[tp3[0]], -m[tp3[1]], m[tp3[2]]]).parameters transf = np.zeros(shape=(3, 3)) - transf[0][t[0]] = -1 - transf[1][t[1]] = -1 - transf[2][t[2]] = 1 + transf[0][tp3[0]] = -1 + transf[1][tp3[1]] = -1 + transf[2][tp3[2]] = 1 alpha = math.pi * alpha / 180 new_matrix = [ [a, 0, 0], @@ -751,9 +775,9 @@ def get_conventional_standard_structure(self, international_monoclinic=True, kee if alpha < 90 and b < c: transf = np.zeros(shape=(3, 3)) - transf[0][t[0]] = 1 - transf[1][t[1]] = 1 - transf[2][t[2]] = 1 + transf[0][tp3[0]] = 1 + transf[1][tp3[1]] = 1 + transf[2][tp3[2]] = 1 alpha = math.pi * alpha / 180 new_matrix = [ [a, 0, 0], @@ -883,7 +907,7 @@ def is_all_acute_or_obtuse(matrix) -> bool: ) return new_struct.get_sorted_structure() - def get_kpoint_weights(self, kpoints, atol=1e-5): + def get_kpoint_weights(self, kpoints: Sequence[Kpoint], atol: float = 1e-5) -> list[float]: """Calculate the weights for a list of kpoints. Args: @@ -918,20 +942,20 @@ def get_kpoint_weights(self, kpoints, atol=1e-5): mapping = list(mapping) grid = (np.array(grid) + np.array(shift) * (0.5, 0.5, 0.5)) / mesh weights = [] - mapped = defaultdict(int) + mapped: dict[tuple, int] = defaultdict(int) for kpt in kpoints: for idx, g in enumerate(grid): if np.allclose(pbc_diff(kpt, g), (0, 0, 0), atol=atol): mapped[tuple(g)] += 1 weights.append(mapping.count(mapping[idx])) break - if (len(mapped) != len(set(mapping))) or (not all(v == 1 for v in mapped.values())): + if (len(mapped) != len(set(mapping))) or any(v != 1 for v in mapped.values()): raise ValueError("Unable to find 1:1 corresponding between input kpoints and irreducible grid!") return [w / sum(weights) for w in weights] def is_laue(self) -> bool: """Check if the point group of the structure has Laue symmetry (centrosymmetry).""" - laue = ("-1", "2/m", "mmm", "4/m", "4/mmm", "-3", "-3m", "6/m", "6/mmm", "m-3", "m-3m") + laue = {"-1", "2/m", "mmm", "4/m", "4/mmm", "-3", "-3m", "6/m", "6/mmm", "m-3", "m-3m"} return str(self.get_point_group_symbol()) in laue @@ -961,7 +985,13 @@ class PointGroupAnalyzer: inversion_op = SymmOp.inversion() - def __init__(self, mol, tolerance=0.3, eigen_tolerance=0.01, matrix_tolerance=0.1): + def __init__( + self, + mol: Molecule, + tolerance: float = 0.3, + eigen_tolerance: float = 0.01, + matrix_tolerance: float = 0.1, + ) -> None: """The default settings are usually sufficient. Args: @@ -979,10 +1009,10 @@ def __init__(self, mol, tolerance=0.3, eigen_tolerance=0.01, matrix_tolerance=0. self.eig_tol = eigen_tolerance self.mat_tol = matrix_tolerance self._analyze() - if self.sch_symbol in ["C1v", "C1h"]: - self.sch_symbol = "Cs" + if self.sch_symbol in {"C1v", "C1h"}: + self.sch_symbol: str = "Cs" - def _analyze(self): + def _analyze(self) -> None: if len(self.centered_mol) == 1: self.sch_symbol = "Kh" else: @@ -993,7 +1023,7 @@ def _analyze(self): wt = site.species.weight for i in range(3): inertia_tensor[i, i] += wt * (c[(i + 1) % 3] ** 2 + c[(i + 2) % 3] ** 2) - for i, j in [(0, 1), (1, 2), (0, 2)]: + for i, j in ((0, 1), (1, 2), (0, 2)): inertia_tensor[i, j] += -wt * c[i] * c[j] inertia_tensor[j, i] += -wt * c[j] * c[i] total_inertia += wt * np.dot(c, c) @@ -1010,8 +1040,8 @@ def _analyze(self): eig_all_same = abs(v1 - v2) < self.eig_tol and abs(v1 - v3) < self.eig_tol eig_all_diff = abs(v1 - v2) > self.eig_tol and abs(v1 - v3) > self.eig_tol and abs(v2 - v3) > self.eig_tol - self.rot_sym = [] - self.symmops = [SymmOp(np.eye(4))] + self.rot_sym: list = [] + self.symmops: list[SymmOp] = [SymmOp(np.eye(4))] if eig_zero: logger.debug("Linear molecule detected") self._proc_linear() @@ -1025,14 +1055,14 @@ def _analyze(self): logger.debug("Symmetric top molecule detected") self._proc_sym_top() - def _proc_linear(self): + def _proc_linear(self) -> None: if self.is_valid_op(PointGroupAnalyzer.inversion_op): self.sch_symbol = "D*h" self.symmops.append(PointGroupAnalyzer.inversion_op) else: self.sch_symbol = "C*v" - def _proc_asym_top(self): + def _proc_asym_top(self) -> None: """Handles asymmetric top molecules, which cannot contain rotational symmetry larger than 2. """ @@ -1047,7 +1077,7 @@ def _proc_asym_top(self): logger.debug("Cyclic group detected.") self._proc_cyclic() - def _proc_sym_top(self): + def _proc_sym_top(self) -> None: """Handles symmetric top molecules which has one unique eigenvalue whose corresponding principal axis is a unique rotational axis. @@ -1074,7 +1104,7 @@ def _proc_sym_top(self): else: self._proc_no_rot_sym() - def _proc_no_rot_sym(self): + def _proc_no_rot_sym(self) -> None: """Handles molecules with no rotational symmetry. Only possible point groups are C1, Cs and Ci. @@ -1090,7 +1120,7 @@ def _proc_no_rot_sym(self): self.sch_symbol = "Cs" break - def _proc_cyclic(self): + def _proc_cyclic(self) -> None: """Handles cyclic group molecules.""" main_axis, rot = max(self.rot_sym, key=lambda v: v[1]) self.sch_symbol = f"C{rot}" @@ -1102,7 +1132,7 @@ def _proc_cyclic(self): elif mirror_type == "" and self.is_valid_op(SymmOp.rotoreflection(main_axis, angle=180 / rot)): self.sch_symbol = f"S{2 * rot}" - def _proc_dihedral(self): + def _proc_dihedral(self) -> None: """Handles dihedral group molecules, i.e those with intersecting R2 axes and a main axis. """ @@ -1114,7 +1144,7 @@ def _proc_dihedral(self): elif mirror_type != "": self.sch_symbol += "d" - def _check_R2_axes_asym(self): + def _check_R2_axes_asym(self) -> None: """Test for 2-fold rotation along the principal axes. Used to handle asymmetric top molecules. @@ -1125,14 +1155,14 @@ def _check_R2_axes_asym(self): self.symmops.append(op) self.rot_sym.append((v, 2)) - def _find_mirror(self, axis): + def _find_mirror(self, axis: NDArray) -> Literal["h", "d", "v", ""]: """Looks for mirror symmetry of specified type about axis. Possible types are "h" or "vd". Horizontal (h) mirrors are perpendicular to the axis while vertical (v) or diagonal (d) mirrors are parallel. v mirrors has atoms lying on the mirror plane while d mirrors do not. """ - mirror_type = "" + mirror_type: Literal["h", "d", "v", ""] = "" # First test whether the axis itself is the normal to a mirror plane. if self.is_valid_op(SymmOp.reflection(axis)): @@ -1159,7 +1189,7 @@ def _find_mirror(self, axis): return mirror_type - def _get_smallest_set_not_on_axis(self, axis): + def _get_smallest_set_not_on_axis(self, axis: NDArray) -> list: """Get the smallest list of atoms with the same species and distance from origin AND does not lie on the specified axis. @@ -1168,8 +1198,7 @@ def _get_smallest_set_not_on_axis(self, axis): """ def not_on_axis(site): - v = np.cross(site.coords, axis) - return np.linalg.norm(v) > self.tol + return np.linalg.norm(np.cross(site.coords, axis)) > self.tol valid_sets = [] _origin_site, dist_el_sites = cluster_sites(self.centered_mol, self.tol) @@ -1180,7 +1209,7 @@ def not_on_axis(site): return min(valid_sets, key=len) - def _check_rot_sym(self, axis): + def _check_rot_sym(self, axis: NDArray) -> int: """Determine the rotational symmetry about supplied axis. Used only for symmetric top molecules which has possible rotational symmetry @@ -1188,18 +1217,17 @@ def _check_rot_sym(self, axis): """ min_set = self._get_smallest_set_not_on_axis(axis) max_sym = len(min_set) - for i in range(max_sym, 0, -1): - if max_sym % i != 0: + for idx in range(max_sym, 0, -1): + if max_sym % idx != 0: continue - op = SymmOp.from_axis_angle_and_translation(axis, 360 / i) - rotvalid = self.is_valid_op(op) - if rotvalid: + op = SymmOp.from_axis_angle_and_translation(axis, 360 / idx) + if self.is_valid_op(op): self.symmops.append(op) - self.rot_sym.append((axis, i)) - return i + self.rot_sym.append((axis, idx)) + return idx return 1 - def _check_perpendicular_r2_axis(self, axis): + def _check_perpendicular_r2_axis(self, axis: NDArray) -> None | Literal[True]: """Check for R2 axes perpendicular to unique axis. For handling symmetric top molecules. @@ -1209,14 +1237,13 @@ def _check_perpendicular_r2_axis(self, axis): test_axis = np.cross(s1.coords - s2.coords, axis) if np.linalg.norm(test_axis) > self.tol: op = SymmOp.from_axis_angle_and_translation(test_axis, 180) - r2present = self.is_valid_op(op) - if r2present: + if self.is_valid_op(op): self.symmops.append(op) self.rot_sym.append((test_axis, 2)) return True return None - def _proc_sph_top(self): + def _proc_sph_top(self) -> None: """Handles Spherical Top Molecules, which belongs to the T, O or I point groups. """ @@ -1228,22 +1255,24 @@ def _proc_sph_top(self): if rot < 3: logger.debug("Accidental spherical top!") self._proc_sym_top() + elif rot == 3: mirror_type = self._find_mirror(main_axis) - if mirror_type != "": - if self.is_valid_op(PointGroupAnalyzer.inversion_op): - self.symmops.append(PointGroupAnalyzer.inversion_op) - self.sch_symbol = "Th" - else: - self.sch_symbol = "Td" - else: + if mirror_type == "": self.sch_symbol = "T" + elif self.is_valid_op(PointGroupAnalyzer.inversion_op): + self.symmops.append(PointGroupAnalyzer.inversion_op) + self.sch_symbol = "Th" + else: + self.sch_symbol = "Td" + elif rot == 4: if self.is_valid_op(PointGroupAnalyzer.inversion_op): self.symmops.append(PointGroupAnalyzer.inversion_op) self.sch_symbol = "Oh" else: self.sch_symbol = "O" + elif rot == 5: if self.is_valid_op(PointGroupAnalyzer.inversion_op): self.symmops.append(PointGroupAnalyzer.inversion_op) @@ -1251,14 +1280,14 @@ def _proc_sph_top(self): else: self.sch_symbol = "I" - def _find_spherical_axes(self): + def _find_spherical_axes(self) -> None: """Looks for R5, R4, R3 and R2 axes in spherical top molecules. Point group T molecules have only one unique 3-fold and one unique 2-fold axis. O molecules have one unique 4, 3 and 2-fold axes. I molecules have a unique 5-fold axis. """ - rot_present = defaultdict(bool) + rot_present: dict[int, bool] = defaultdict(bool) _origin_site, dist_el_sites = cluster_sites(self.centered_mol, self.tol) test_set = min(dist_el_sites.values(), key=len) coords = [s.coords for s in test_set] @@ -1286,21 +1315,20 @@ def _find_spherical_axes(self): if rot_present[2] and rot_present[3] and (rot_present[4] or rot_present[5]): break - def get_pointgroup(self): + def get_pointgroup(self) -> PointGroupOperations: """Get a PointGroup object for the molecule.""" return PointGroupOperations(self.sch_symbol, self.symmops, self.mat_tol) - def get_symmetry_operations(self): - """Return symmetry operations as a list of SymmOp objects. Returns Cartesian coord - symmops. + def get_symmetry_operations(self) -> Sequence[SymmOp]: + """Get symmetry operations. Returns: - list[SymmOp]: symmetry operations. + list[SymmOp]: symmetry operations in Cartesian coord. """ return generate_full_symmops(self.symmops, self.tol) - def get_rotational_symmetry_number(self): - """Return the rotational symmetry number.""" + def get_rotational_symmetry_number(self) -> int: + """Get the rotational symmetry number.""" symm_ops = self.get_symmetry_operations() symm_number = 0 for symm in symm_ops: @@ -1309,7 +1337,7 @@ def get_rotational_symmetry_number(self): symm_number += 1 return symm_number - def is_valid_op(self, symm_op) -> bool: + def is_valid_op(self, symm_op: SymmOp) -> bool: """Check if a particular symmetry operation is a valid symmetry operation for a molecule, i.e., the operation maps all atoms to another equivalent atom. @@ -1323,11 +1351,11 @@ def is_valid_op(self, symm_op) -> bool: for site in self.centered_mol: coord = symm_op.operate(site.coords) ind = find_in_coord_list(coords, coord, self.tol) - if not (len(ind) == 1 and self.centered_mol[ind[0]].species == site.species): + if len(ind) != 1 or self.centered_mol[ind[0]].species != site.species: return False return True - def _get_eq_sets(self): + def _get_eq_sets(self) -> dict[Literal["eq_sets", "sym_ops"], Any]: """Calculate the dictionary for mapping equivalent atoms onto each other. Returns: @@ -1338,7 +1366,8 @@ def _get_eq_sets(self): operation that maps atom i unto j. """ UNIT = np.eye(3) - eq_sets, operations = defaultdict(set), defaultdict(dict) + eq_sets: dict[int, set] = defaultdict(set) + operations: dict[int, dict] = defaultdict(dict) symm_ops = [op.rotation_matrix for op in generate_full_symmops(self.symmops, self.tol)] def get_clustered_indices(): @@ -1372,7 +1401,7 @@ def get_clustered_indices(): return {"eq_sets": eq_sets, "sym_ops": operations} @staticmethod - def _combine_eq_sets(equiv_sets, sym_ops): + def _combine_eq_sets(equiv_sets: dict, sym_ops: dict) -> dict: """Combines the dicts of _get_equivalent_atom_dicts into one. Args: @@ -1433,7 +1462,7 @@ def get_equivalent_atoms(self): eq = self._get_eq_sets() return self._combine_eq_sets(eq["eq_sets"], eq["sym_ops"]) - def symmetrize_molecule(self): + def symmetrize_molecule(self) -> dict: """Get a symmetrized molecule. The equivalent atoms obtained via @@ -1468,7 +1497,12 @@ def symmetrize_molecule(self): return {"sym_mol": molecule, "eq_sets": eq_sets, "sym_ops": ops} -def iterative_symmetrize(mol, max_n=10, tolerance=0.3, epsilon=1e-2): +def iterative_symmetrize( + mol: Molecule, + max_n: int = 10, + tolerance: float = 0.3, + epsilon: float = 1e-2, +) -> dict: """Get a symmetrized molecule. The equivalent atoms obtained via @@ -1489,7 +1523,6 @@ def iterative_symmetrize(mol, max_n=10, tolerance=0.3, epsilon=1e-2): subsequently symmetrized structures is smaller epsilon, the iteration stops before max_n is reached. - Returns: dict: with three possible keys: sym_mol: A symmetrized molecule instance. @@ -1512,7 +1545,11 @@ def iterative_symmetrize(mol, max_n=10, tolerance=0.3, epsilon=1e-2): return eq -def cluster_sites(mol: Molecule, tol: float, give_only_index: bool = False) -> tuple[Site | None, dict]: +def cluster_sites( + mol: Molecule, + tol: float, + give_only_index: bool = False, +) -> tuple[Site | None, dict]: """Cluster sites based on distance and species type. Args: @@ -1547,7 +1584,10 @@ def cluster_sites(mol: Molecule, tol: float, give_only_index: bool = False) -> t return origin_site, clustered_sites -def generate_full_symmops(symmops: Sequence[SymmOp], tol: float) -> Sequence[SymmOp]: +def generate_full_symmops( + symmops: Sequence[SymmOp], + tol: float, +) -> Sequence[SymmOp]: """Recursive algorithm to permute through all possible combinations of the initially supplied symmetry operations to arrive at a complete set of operations mapping a single atom to all other equivalent atoms in the point group. This assumes that the @@ -1592,7 +1632,12 @@ def generate_full_symmops(symmops: Sequence[SymmOp], tol: float) -> Sequence[Sym class SpacegroupOperations(list): """Represents a space group, which is a collection of symmetry operations.""" - def __init__(self, int_symbol, int_number, symmops): + def __init__( + self, + int_symbol: str, + int_number: int, + symmops: Sequence[SymmOp], + ) -> None: """ Args: int_symbol (str): International symbol of the spacegroup. @@ -1604,7 +1649,15 @@ def __init__(self, int_symbol, int_number, symmops): self.int_number = int_number super().__init__(symmops) - def are_symmetrically_equivalent(self, sites1, sites2, symm_prec=1e-3) -> bool: + def __str__(self) -> str: + return f"{self.int_symbol} ({self.int_number}) spacegroup" + + def are_symmetrically_equivalent( + self, + sites1: set[PeriodicSite], + sites2: set[PeriodicSite], + symm_prec: float = 1e-3, + ) -> bool: """Given two sets of PeriodicSites, test if they are actually symmetrically equivalent under this space group. Useful, for example, if you want to test if selecting atoms 1 and 2 out of a set of 4 atoms are symmetrically the same as @@ -1635,9 +1688,6 @@ def in_sites(site): return True return False - def __str__(self): - return f"{self.int_symbol} ({self.int_number}) spacegroup" - class PointGroupOperations(list): """Represents a point group, which is a sequence of symmetry operations. @@ -1646,7 +1696,12 @@ class PointGroupOperations(list): sch_symbol (str): Schoenflies symbol of the point group. """ - def __init__(self, sch_symbol, operations, tol: float = 0.1): + def __init__( + self, + sch_symbol: str, + operations: Sequence[SymmOp], + tol: float = 0.1, + ) -> None: """ Args: sch_symbol (str): Schoenflies symbol of the point group. @@ -1659,5 +1714,5 @@ def __init__(self, sch_symbol, operations, tol: float = 0.1): self.sch_symbol = sch_symbol super().__init__(generate_full_symmops(operations, tol)) - def __repr__(self): + def __repr__(self) -> str: return self.sch_symbol diff --git a/tests/symmetry/test_analyzer.py b/tests/symmetry/test_analyzer.py index 84594adc2c7..41c7b72e6c1 100644 --- a/tests/symmetry/test_analyzer.py +++ b/tests/symmetry/test_analyzer.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import asdict from unittest import TestCase import numpy as np @@ -17,6 +18,7 @@ from pymatgen.symmetry.structure import SymmetrizedStructure from pymatgen.util.testing import TEST_FILES_DIR, VASP_IN_DIR, VASP_OUT_DIR, PymatgenTest from pytest import approx, raises +from spglib import SpglibDataset TEST_DIR = f"{TEST_FILES_DIR}/symmetry/analyzer" @@ -123,7 +125,7 @@ def test_get_point_group_operations_uniq(self): def test_get_symmetry_dataset(self): ds = self.sg.get_symmetry_dataset() - assert ds["international"] == "Pnma" + assert ds.international == "Pnma" def test_init_cell(self): # see https://github.com/materialsproject/pymatgen/pull/3179 @@ -170,12 +172,12 @@ def test_get_crystal_system(self): assert crystal_system == "orthorhombic" assert self.disordered_sg.get_crystal_system() == "tetragonal" - # orig_spg = self.sg._space_group_data["number"] - # self.sg._space_group_data["number"] = 0 - # with pytest.raises(ValueError, match="Received invalid space group 0"): - # self.sg.get_crystal_system() - # - # self.sg._space_group_data["number"] = orig_spg + def test_invalid_space_group_number(self): + invalid_sg = asdict(self.sg.get_symmetry_dataset()) + invalid_sg["number"] = 0 + self.sg._space_group_data = SpglibDataset(**invalid_sg) + with pytest.raises(ValueError, match="Received invalid space group 0"): + self.sg.get_crystal_system() def test_get_refined_structure(self): for pg_analyzer in self.sg.get_refined_structure().lattice.angles: