diff --git a/smol/cofe/space/clusterspace.py b/smol/cofe/space/clusterspace.py index d6b5439d6..11b3e8ac4 100644 --- a/smol/cofe/space/clusterspace.py +++ b/smol/cofe/space/clusterspace.py @@ -4,7 +4,6 @@ from importlib import import_module import warnings import numpy as np -from itertools import combinations from monty.json import MSONable from pymatgen.core import Structure, PeriodicSite from pymatgen.symmetry.analyzer import SpacegroupAnalyzer, SymmOp @@ -14,7 +13,7 @@ lattice_points_in_supercell, coord_list_mapping_pbc from src.mc_utils import corr_from_occupancy from smol.cofe.space import Orbit, basis_factory, get_site_spaces, \ - get_allowed_species, Vacancy, Cluster + get_allowed_species, Vacancy from smol.exceptions import SymmetryError, StructureMatchError, \ SYMMETRY_ERROR_MESSAGE from smol.cofe.space.constants import SITE_TOL @@ -157,10 +156,6 @@ def __init__(self, structure, expansion_structure, symops, orbits, # already been matched self._supercell_orb_inds = {} - # 2D lists to store 1-level-down hierarchy info. (One level only!) - # Will be cleaned after any change of orbits! - self._bit_combo_hierarchy = None - # assign the cluster ids self._assign_orbit_ids() @@ -351,53 +346,6 @@ def function_total_multiplicities(self): return self.orbit_multiplicities[self.function_orbit_ids] * \ self.function_ordering_multiplicities - def bit_combo_hierarchy(self, min_size=2, invert=False): - """Get 1-level-down hierarchy of correlation functions. - - The size difference between the current corr function and its - sub-clusters is only 1! We only give one-level down hierarchy. Because - that is enough to contrain hierarchy! - - Note: Since complete, all level hierarchy table is not practical - for CE fit, we will not include it as an attribute of this class. - If you still want to see it, call function get_complete_mapping - in this module. - - Args: - min_size (int): optional - Minimum size required for the correlation function. If the size - of the correlation function is smaller or equals to min_size, - will not search for its sub-clusters. For hierarchy - constraints, the recommended setting is 2. - invert (bool): optional - Default is invert=False which gives the high to low bit combo - hierarchy. Invert= True will invert the hierarchy into low to - high - - Returns: - list of lists: Each sublist is of of length self.num_corr_function - and contains integer indices of correlation functions that are - contained by the correlation function with the current index. - """ - if self._bit_combo_hierarchy is None: - self._bit_combo_hierarchy = self._get_hierarchy_up_to_low() - - # array of orbit sizes for each bit id. - all_sizes = np.array([0] + [self.orbits[i - 1].base_cluster.size - for i in self.function_orbit_ids[1:]]) - - # Stores a hierarchy from cluster size 1, then retrieves from min_size. - hierarchy = [] - for vals, size in zip(self._bit_combo_hierarchy, all_sizes): - if size <= min_size: - hierarchy.append([]) - else: - hierarchy.append(vals) - if invert: - return invert_mapping_table(hierarchy) - else: - return hierarchy - @property def basis_orthogonal(self): """Check if the orbit basis defined is orthogonal.""" @@ -417,6 +365,57 @@ def external_terms(self): """ return self._external_terms + def orbit_hierarchy(self, level=1, min_size=1): + """Get orbit hierarchy by ids. + + The empty/constant cluster index 0 is technically a suborbit of all + orbits, but is not added to the hierarchy entries. + + Args: + level (int): + min_size (int): + + Returns: + list of list: each element of the inner lists is the orbit id for + all suborbits corresponding to the orbit at the given outer list + index. + """ + sub_ids = [ + [suborb.id for suborb in self.get_sub_orbits( + orb.id, level=level, min_size=min_size)] + for orb in self.orbits + ] + + return [[], ] + sub_ids + + def function_hierarchy(self, level=1, min_size=2, invert=False): + """Get the correlation function hierarchy. + + The function hierarchy is t + + Args: + level (int): + min_size (int): + invert (bool): optional + Default is invert=False which gives the high to low bit combo + hierarchy. Invert= True will invert the hierarchy into low to + high + + Returns: + list of list: each element of the inner lists is the bit id for + all corr functions corresponding to the corr function at the given + outer list index. + """ + hierarchy = [ + self.get_sub_function_ids(i, level=level, min_size=min_size) + for i in range(self.num_corr_functions) + ] + + if invert: + hierarchy = invert_mapping(hierarchy) + + return hierarchy + def orbits_from_cutoffs(self, upper, lower=0): """Get orbits with clusters within given diameter cutoffs (inclusive). @@ -760,9 +759,8 @@ def remove_orbits(self, orbit_ids): if orbit.id not in orbit_ids] self._assign_orbit_ids() # Re-assign ids - # Clear the cached supercell orbit mappings and hierarchy + # Clear the cached supercell orbit mappings self._supercell_orb_inds = {} - self._bit_combo_hierarchy = None def remove_orbit_bit_combos(self, bit_ids): """Remove orbit bit combos by their ids. @@ -807,9 +805,6 @@ def remove_orbit_bit_combos(self, bit_ids): else: self._assign_orbit_ids() # Re-assign ids - # clear hierarchy - self._bit_combo_hierarchy = None - def copy(self): """Deep copy of instance.""" return deepcopy(self) @@ -834,6 +829,76 @@ def structure_site_mapping(self, supercell, structure): 'structure.') return mapping.tolist() + def get_sub_orbits(self, orbit_id, level=1, min_size=1): + """Get sub orbits of the orbit for the corresponding orbit_id. + + Args: + orbit_id (int): + id of orbit to get sub orbit id for + level (int): optional + how many levels down to look for suborbits. If all suborbits + are needed make level large enough or set to None. + min_size (int): optional + minimum size of clusters in sub orbits to include + + Returns: + list of ints: list containing ids of suborbits + """ + if orbit_id == 0: + return [] + size = self.orbits[orbit_id - 1].base_cluster.size + if level is None or level < 0 or size - level - 1 < 0: + stop = 0 + elif min_size > size - level: + stop = min_size - 1 + else: + stop = size - level - 1 + + search_sizes = range(size - 1, stop, -1) + return [orbit for s in search_sizes for orbit in self._orbits[s] + if self.orbits[orbit_id - 1].is_sub_orbit(orbit)] + + def get_sub_function_ids(self, corr_id, level=1, min_size=1): + """Get the bit combo ids of all sub correlation functions. + + A sub correlation function of a given correlation function means that + the sub correlation fucntion is a factor of the correlation function + (with the additional requirement of acting over the sites in sub + clusters of the clusters over which the given corr function acts on). + + In other works think of it an orbit of function labeled subclusters + of a given orbit of function labeled clusters...a mouthful... + + Args: + corr_id (int): + id of orbit to get sub orbit id for + level (int): optional + how many levels down to look for suborbits. If all suborbits + are needed make level large enough or set to None. + min_size (int): optional + minimum size of clusters in sub orbits to include + + Returns: + list of ints: list containing ids of sub correlation functions + """ + if corr_id == 0: + return [] + + orbit = self.orbits[self.function_orbit_ids[corr_id] - 1] + bit_combo = orbit.bit_combos[corr_id - orbit.bit_id] + + sub_fun_ids = [] + for sub_orbit in self.get_sub_orbits(orbit.id, level=level, + min_size=min_size): + inds = orbit.sub_orbit_mappings(sub_orbit) + for i, sub_bit_combo in enumerate(sub_orbit.bit_combos): + if np.any( + np.all( + sub_bit_combo[0] == bit_combo[:, inds], axis=2)): + sub_fun_ids.append(sub_orbit.bit_id + i) + + return sub_fun_ids + def _assign_orbit_ids(self): """Assign unique id's to orbit. @@ -971,9 +1036,9 @@ def _gen_orbit_indices(self, scmatrix): # coordinate index] tcoords = fcoords[:, None, :, :] + ts[None, :, None, :] tcs = tcoords.shape - inds = coord_list_mapping_pbc(tcoords.reshape((-1, 3)), - supercell_fcoords, - atol=SITE_TOL).reshape((tcs[0] * tcs[1], tcs[2])) # noqa + inds = coord_list_mapping_pbc( + tcoords.reshape((-1, 3)), supercell_fcoords, + atol=SITE_TOL).reshape((tcs[0] * tcs[1], tcs[2])) # noqa # orbit_ids holds orbit, and 2d array of index groups that # correspond to the orbit # the 2d array may have some duplicates. This is due to @@ -985,92 +1050,6 @@ def _gen_orbit_indices(self, scmatrix): return orbit_indices - def _find_sub_cluster(self, bit_id, min_size=1): - """Find 1-level-down subclusters of a given correlation function. - - Args: - bit_id (int): - Index of the correlation function to find subclusters with. - min_size (int): optional - Minimum size required for the correlation function. If the - size of the correlation function is smaller or equals to - min_size, will not search for its sub-clusters. - - Returns: - list: A list of integer indices specifying which correlation - functions are 1-level-down subclusters of the given correlation - function. - """ - if bit_id == 0: # Constant term - return [] - - # Separate the zero term out. - all_sizes = np.array([0] + [self.orbits[i - 1].base_cluster.size - for i in self.function_orbit_ids[1:]]) - - bit_combos = self.all_bit_combos[bit_id] - orbit = self.orbits[self.function_orbit_ids[bit_id] - 1] - sites = orbit.base_cluster.sites - - size = all_sizes[bit_id] - if size <= min_size: - return [] - - possible_sub_ids = np.where(all_sizes == (size - 1))[0] - lattice = self._exp_structure.lattice - sub_indices = [] - for comb in combinations(np.arange(size), size - 1): - sub_sites = np.array(sites[np.array(comb), :]) - sub_bit_combos = np.array(bit_combos[:, np.array(comb)]) - sub_cluster = Cluster(sub_sites, lattice) - sub_equiv = [sub_cluster] - for symop in self.symops: - new_sites = symop.operate_multi(sub_sites) - c = Cluster(new_sites, lattice) - if c not in sub_equiv: - sub_equiv.append(c) - - for sub_id in possible_sub_ids: - sub_bit_combo = self.all_bit_combos[sub_id] - sub_orbit = self.orbits[self.function_orbit_ids[sub_id] - 1] - cluster_match = sub_orbit.base_cluster in sub_equiv - bit_match = np.any( - np.all(sub_bit_combo[0] == sub_bit_combos, axis=1)) - - if cluster_match and bit_match: - if sub_id in sub_indices: - continue - else: - sub_indices.append(sub_id) - - return sub_indices - - def _get_hierarchy_up_to_low(self, min_size=1): - """Generate high-to-low hierarchy. - - The size difference between the current corr function and its - sub clusters is only 1! We only give one-level down hierarchy - Because it would be enough to contrain hierarchy! - - Args: - min_size: - Minimum size required for the correlation function. If the - size of the correlation function is smaller or equals to - min_size, will not search for its sub-clusters. - - Returns: - list of lists: Each sublist of length self.num_corr_function - contains integer indices of correlation functions that are - contained by the correlation function with the current index. - """ - up_low_hierarchy = [[] for i in range(self.num_corr_functions)] - - for ii in np.flip(np.arange(self.num_corr_functions)): - sub_indices = self._find_sub_cluster(ii, min_size=min_size) - up_low_hierarchy[ii] = sub_indices - - return up_low_hierarchy - def __eq__(self, other): """Check equality between cluster subspaces.""" if not isinstance(other, ClusterSubspace): @@ -1143,7 +1122,6 @@ def from_dict(cls, d): np.array(ind)) for o_id, ind in orb_inds] cs._supercell_orb_inds = _supercell_orb_inds - cs._bit_combo_hierarchy = d.get('_bc_hierarchy') return cs def as_dict(self): @@ -1167,12 +1145,11 @@ def as_dict(self): 'sc_matcher': self._sc_matcher.as_dict(), 'site_matcher': self._site_matcher.as_dict(), 'external_terms': [et.as_dict() for et in self.external_terms], - '_supercell_orb_inds': _supercell_orb_inds, - '_bc_hierarchy': self._bit_combo_hierarchy} + '_supercell_orb_inds': _supercell_orb_inds} return d -def invert_mapping_table(mapping): +def invert_mapping(mapping): """Invert a mapping table from forward to backward, vice versa. Args: diff --git a/smol/cofe/space/orbit.py b/smol/cofe/space/orbit.py index 668202774..c352c509d 100644 --- a/smol/cofe/space/orbit.py +++ b/smol/cofe/space/orbit.py @@ -6,13 +6,13 @@ __author__ = "Luis Barroso-Luque, William Davidson Richard" -from itertools import chain, product, accumulate +from itertools import chain, product, accumulate, combinations import numpy as np from monty.json import MSONable from pymatgen.core import Lattice from pymatgen.core.operations import SymmOp -from pymatgen.util.coord import coord_list_mapping +from pymatgen.util.coord import coord_list_mapping, is_coord_subset from smol.utils import _repr from smol.exceptions import SymmetryError, SYMMETRY_ERROR_MESSAGE @@ -97,7 +97,6 @@ def __init__(self, sites, lattice, bits, site_bases, structure_symops): # Create basecluster self.base_cluster = Cluster(sites, lattice) - self.lattice = lattice @property def basis_type(self): @@ -136,6 +135,11 @@ def bit_combos(self): np.array(c, dtype=np.int_) for c in all_combos) return self._bit_combos + @property + def site_spaces(self): + """Get the site spaces for the site basis associate with each site.""" + return [site_basis.site_space for site_basis in self.site_bases] + @property def bit_combo_array(self): """Single array of all bit combos.""" @@ -164,7 +168,7 @@ def clusters(self): equiv = [self.base_cluster] for symop in self.structure_symops: new_sites = symop.operate_multi(self.base_cluster.sites) - c = Cluster(new_sites, self.lattice) + c = Cluster(new_sites, self.base_cluster.lattice) if c not in equiv: equiv.append(c) self._equiv = equiv @@ -184,18 +188,22 @@ def cluster_symops(self): """ if self._symops: return self._symops + self._symops = [] for symop in self.structure_symops: new_sites = symop.operate_multi(self.base_cluster.sites) - c = Cluster(new_sites, self.base_cluster.lattice) - if c == self.base_cluster: - recenter = np.round(self.base_cluster.centroid - c.centroid) - c_sites = c.sites + recenter - mapping = tuple(coord_list_mapping(self.base_cluster.sites, - c_sites, atol=SITE_TOL)) + cluster = Cluster(new_sites, self.base_cluster.lattice) + if cluster == self.base_cluster: + recenter = np.round( + self.base_cluster.centroid - cluster.centroid) + c_sites = cluster.sites + recenter + mapping = tuple(coord_list_mapping( + self.base_cluster.sites, c_sites, atol=SITE_TOL)) self._symops.append((symop, mapping)) + if len(self._symops) * self.multiplicity != len(self.structure_symops): raise SymmetryError(SYMMETRY_ERROR_MESSAGE) + return self._symops @property @@ -328,6 +336,66 @@ def assign_ids(self, orbit_id, orbit_bit_id, start_cluster_id): c_id = c.assign_ids(c_id) return orbit_id + 1, orbit_bit_id + len(self.bit_combos), c_id + def is_sub_orbit(self, orbit): + """Check if given orbits clusters are subclusters. + + Note this does not consider bit_combos + Args: + orbit (Orbit): + Orbit object to check if + Returns: + bool: True if the clusters of given orbit are subclusters. + """ + if self.base_cluster.size <= orbit.base_cluster.size: + return False + elif not np.all(sp in self.site_spaces for sp in orbit.site_spaces): + return False + + match = any( + Cluster(self.base_cluster.sites[inds, :], + self.base_cluster.lattice) + in orbit.clusters + for inds in combinations( + range(self.base_cluster.size), orbit.base_cluster.size)) + + return match + + def sub_orbit_mappings(self, orbit): + """Return a mapping of the sites in the orbit to a sub orbit. + + If the given orbit is not a sub-orbit will return an empty list. + Note this works for mapping between sites, sites spaces, and basis + functions associated with each site. + + Args: + orbit (Orbit): + A sub orbit to return mapping of sites + Returns: + list: of indices sucht that + self.base_cluster.sites[indices] = orbit.base_cluster.sites + """ + indsets = np.array(list(combinations( + (i for i, space in enumerate(self.site_spaces) + if space in orbit.site_spaces), len(orbit.site_spaces)))) + + mappings = [] + for cluster in self.clusters: + for inds in indsets: + # take the centroid of subset of sites, not all cluster sites + centroid = np.average(cluster.sites[inds], axis=0) + recenter = np.round(centroid - orbit.base_cluster.centroid) + c_sites = orbit.base_cluster.sites + recenter + if is_coord_subset(c_sites, cluster.sites): + mappings.append( + coord_list_mapping( + c_sites, cluster.sites, atol=SITE_TOL)) + + if len(mappings) == 0 and self.is_sub_orbit(orbit): + raise RuntimeError( + "The given orbit is a suborbit, but no site mappings were " + "found!\n Something is very wrong here!") + return np.unique(mappings, axis=0) + def __len__(self): """Get total number of orbit basis functions. @@ -351,15 +419,16 @@ def __str__(self): return f'[Orbit] id: {self.id:<3}' \ f'orderings: {len(self):<4}' \ f'multiplicity: {self.multiplicity:<4}' \ - f' no. symops: {len(self.cluster_symops):<4}\n' \ - f' {str(self.base_cluster)}' + f' no. symops: {len(self.cluster_symops):<4}\n'\ + f' {self.site_spaces}\n' \ + f' {str(self.base_cluster)}' def __repr__(self): """Get Orbit representation.""" return _repr(self, orb_id=self.id, orb_b_id=self.bit_id, radius=self.base_cluster.radius, - lattice=self.lattice, + lattice=self.base_cluster.lattice, basecluster=self.base_cluster) @classmethod @@ -387,7 +456,7 @@ def as_dict(self): d = {"@module": self.__class__.__module__, "@class": self.__class__.__name__, "sites": self.base_cluster.sites.tolist(), - "lattice": self.lattice.as_dict(), + "lattice": self.base_cluster.lattice.as_dict(), "bits": self.bits, "site_bases": [sb.as_dict() for sb in self.site_bases], "structure_symops": [so.as_dict() for so in diff --git a/tests/test_cofe/test_clusterspace.py b/tests/test_cofe/test_clusterspace.py index ce2329b4b..2c44af773 100644 --- a/tests/test_cofe/test_clusterspace.py +++ b/tests/test_cofe/test_clusterspace.py @@ -9,7 +9,7 @@ from pymatgen.util.coord import is_coord_subset_pbc from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from smol.cofe import ClusterSubspace -from smol.cofe.space.clusterspace import invert_mapping_table,\ +from smol.cofe.space.clusterspace import invert_mapping,\ get_complete_mapping from smol.cofe.extern import EwaldTerm from smol.cofe.space.constants import SITE_TOL @@ -19,23 +19,27 @@ # TODO test correlations for ternary and for applications of symops to structure + + def test_invert_mapping_table(): - forward = [[],[],[1],[1],[1],[2,4],[3,4],[2,3],[5,6,7]] - backward = [[],[2,3,4],[5,7],[6,7],[5,6],[8],[8],[8],[]] + forward = [[], [], [1], [1], [1], [2, 4], [3, 4], [2, 3], [5, 6, 7]] + backward = [[], [2, 3, 4], [5, 7], [6, 7], [5, 6], [8], [8], [8], []] - forward_invert = [sorted(sub) for sub in invert_mapping_table(forward)] - backward_invert = [sorted(sub) for sub in invert_mapping_table(backward)] + forward_invert = [sorted(sub) for sub in invert_mapping(forward)] + backward_invert = [sorted(sub) for sub in invert_mapping(backward)] assert forward_invert == backward assert backward_invert == forward def test_get_complete_mapping(): - forward = [[],[],[1],[1],[1],[2,4],[3,4],[2,3],[5,6,7]] - backward = [[],[2,3,4],[5,7],[6,7],[5,6],[8],[8],[8],[]] + forward = [[], [], [1], [1], [1], [2, 4], [3, 4], [2, 3], [5, 6, 7]] + backward = [[], [2, 3, 4], [5, 7], [6, 7],[5, 6], [8], [8], [8], []] - forward_full = [[],[],[1],[1],[1],[1,2,4],[1,3,4],[1,2,3],[1,2,3,4,5,6,7]] - backward_full = [[],[2,3,4,5,6,7,8],[5,7,8],[6,7,8],[5,6,8],[8],[8],[8],[]] + forward_full = [[], [], [1], [1], [1], [1, 2, 4], [1, 3, 4], [1, 2, 3], + [1, 2, 3, 4, 5, 6, 7]] + backward_full = [[], [2, 3, 4, 5, 6, 7, 8], [5, 7, 8], [6, 7, 8], + [5, 6, 8], [8], [8], [8], []] forward_comp = [sorted(sub) for sub in get_complete_mapping(forward)] backward_comp = [sorted(sub) for sub in get_complete_mapping(backward)] @@ -76,16 +80,25 @@ def setUp(self) -> None: supercell_size='volume') self.domains = get_allowed_species(self.structure) - def test_hierarchy(self): - hierarchy_uplow = self.cs.bit_combo_hierarchy() - self.assertEqual(sorted(hierarchy_uplow[0]), []) - self.assertEqual(sorted(hierarchy_uplow[-1]), [17,21]) - self.assertEqual(sorted(hierarchy_uplow[15]), []) - self.assertEqual(sorted(hierarchy_uplow[35]), [5, 6, 7, 10]) - self.assertEqual(sorted(hierarchy_uplow[55]), [6, 7, 8, 13]) - self.assertEqual(sorted(hierarchy_uplow[75]), [7, 16, 21]) - self.assertEqual(sorted(hierarchy_uplow[95]), [9, 19]) - self.assertEqual(sorted(hierarchy_uplow[115]), [13, 19, 21]) + def test_function_hierarchy(self): + hierarchy = self.cs.function_hierarchy() + self.assertEqual(sorted(hierarchy[0]), []) + self.assertEqual(sorted(hierarchy[-1]), [17, 21]) + self.assertEqual(sorted(hierarchy[15]), []) + self.assertEqual(sorted(hierarchy[35]), [5, 7, 10]) + self.assertEqual(sorted(hierarchy[55]), [6, 8, 13]) + self.assertEqual(sorted(hierarchy[75]), [7, 16, 21]) + self.assertEqual(sorted(hierarchy[95]), [9, 19]) + self.assertEqual(sorted(hierarchy[115]), [13, 19, 21]) + + def test_orbit_hierarchy(self): + hierarchy = self.cs.orbit_hierarchy() + self.assertEqual(sorted(hierarchy[0]), []) # empty + self.assertEqual(sorted(hierarchy[1]), []) # point + self.assertEqual(sorted(hierarchy[3]), [1, 2]) # distinct site pair + self.assertEqual(sorted(hierarchy[4]), [1]) # same site pair + self.assertEqual(sorted(hierarchy[15]), [3, 5]) # triplet + self.assertEqual(sorted(hierarchy[-1]), [6, 7]) def test_numbers(self): # Test the total generated orbits, orderings and clusters are diff --git a/tests/test_cofe/test_orbit.py b/tests/test_cofe/test_orbit.py index 89dc53953..f1398de60 100644 --- a/tests/test_cofe/test_orbit.py +++ b/tests/test_cofe/test_orbit.py @@ -1,5 +1,5 @@ import unittest -from itertools import combinations_with_replacement +from itertools import combinations_with_replacement, combinations import json import numpy as np from pymatgen.core import Lattice, Structure, Composition @@ -15,14 +15,17 @@ def setUp(self) -> None: species = [{'Li': 0.1, 'Ca': 0.1}] * 3 + ['Br'] self.coords = ((0.25, 0.25, 0.25), (0.75, 0.75, 0.75), (0.5, 0.5, 0.5), (0, 0, 0)) - structure = Structure(self.lattice, species, self.coords) - sf = SpacegroupAnalyzer(structure) + self.structure = Structure(self.lattice, species, self.coords) + sf = SpacegroupAnalyzer(self.structure) self.symops = sf.get_symmetry_operations() - self.spaces = [SiteSpace(Composition({'Li': 1.0 / 3.0, 'Ca': 1.0 / 3.0})), - SiteSpace(Composition({'Li': 1.0 / 3.0, 'Ca': 1.0 / 3.0}))] + self.spaces = [ + SiteSpace(Composition({'Li': 1.0 / 3.0, 'Ca': 1.0 / 3.0})), + SiteSpace(Composition({'Li': 1.0 / 3.0, 'Ca': 1.0 / 3.0})), + SiteSpace(Composition({'Li': 1.0 / 3.0, 'Ca': 1.0 / 3.0}))] self.bases = [basis_factory('indicator', bit) for bit in self.spaces] - self.basecluster = Cluster(self.coords[:2], self.lattice) - self.orbit = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], + self.basecluster = Cluster(self.coords[:3], self.lattice) + self.orbit = Orbit(self.coords[:3], self.lattice, + [[0, 1], [0, 1], [0, 1]], self.bases, self.symops) self.orbit.assign_ids(1, 1, 1) @@ -48,18 +51,54 @@ def test_cluster_symops(self): self.assertEqual(len(self.orbit.cluster_symops), 12) def test_eq(self): - orbit1 = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], + orbit1 = Orbit(self.coords[:3], self.lattice, [[0, 1], [0, 1], [0, 1]], self.bases, self.symops) - orbit2 = Orbit(self.coords[:3], self.lattice, [[0, 1], [0, 1], [0, 1]], - self.bases + [None], self.symops) + orbit2 = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], + self.bases[:2], self.symops) self.assertEqual(orbit1, self.orbit) self.assertNotEqual(orbit2, self.orbit) + def test_is_sub_orbit(self): + orbit = Orbit(self.coords[:3], self.lattice, [[0, 1], [0, 1], [0, 1]], + self.bases, self.symops) + self.assertFalse(self.orbit.is_sub_orbit(orbit)) + orbit = Orbit([self.coords[0], self.coords[3]], self.lattice, + [[0, 1], [0, 1]], + self.bases[:2], self.symops) + self.assertFalse(self.orbit.is_sub_orbit(orbit)) + orbit = Orbit([self.coords[0]], self.lattice, [[0, 1]], + [self.bases[0]], self.symops) + self.assertTrue(self.orbit.is_sub_orbit(orbit)) + orbit = Orbit([self.coords[1]], self.lattice, [[0, 1]], + [self.bases[1]], self.symops) + self.assertTrue(self.orbit.is_sub_orbit(orbit)) + orbit = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], + self.bases[:2], self.symops) + self.assertTrue(self.orbit.is_sub_orbit(orbit)) + orbit = Orbit([self.coords[3]], self.lattice, [[0, 1]], + [self.bases[0]], self.symops) + self.assertFalse(self.orbit.is_sub_orbit(orbit)) + + def test_sub_orbit_mappings(self): + orbit = Orbit(self.coords[1:], + self.lattice, [[0, 1], [0, 1], [0, 1]], + self.bases, self.symops) + self.assertEqual(len(self.orbit.sub_orbit_mappings(orbit)), 0) + orbit = Orbit(self.coords[:2], + self.lattice, [[0, 1], [0, 1]], + self.bases[:2], self.symops) + self.assertTrue( + np.array_equal(self.orbit.sub_orbit_mappings(orbit), [[0, 1]])) + orbit = Orbit([self.coords[2]], self.lattice, [[0, 1]], + [self.bases[2]], self.symops) + self.assertTrue( + np.array_equal(self.orbit.sub_orbit_mappings(orbit), [[2]])) + def test_bit_combos(self): # orbit with two symmetrically equivalent sites - self.assertEqual(len(self.orbit), 3) + self.assertEqual(len(self.orbit), 6) orbit = Orbit(self.coords[1:3], self.lattice, [[0, 1], [0, 1]], - self.bases, self.symops) + self.bases[:2], self.symops) # orbit with two symmetrically distinct sites self.assertEqual(len(orbit), 4) @@ -70,7 +109,7 @@ def test_is_orthonormal(self): b.orthonormalize() self.assertTrue(b.is_orthogonal) orbit1 = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], - self.bases, self.symops) + self.bases[:2], self.symops) self.assertTrue(orbit1.basis_orthogonal) self.assertTrue(orbit1.basis_orthonormal) @@ -92,15 +131,12 @@ def test_remove_bit_combo(self): self.assertFalse(any(any(np.array_equal(equiv_bits, b) for b in b_c) for b_c in self.orbit.bit_combos)) - # check that it complains if we remove the last remaining combo - self.assertRaises(RuntimeError, self.orbit.remove_bit_combo, [1, 1]) - def test_remove_bit_combo_by_inds(self): orb1 = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], - self.bases, self.symops) + self.bases[:2], self.symops) orb1.assign_ids(1, 1, 1) orb2 = Orbit(self.coords[:2], self.lattice, [[0, 1], [0, 1]], - self.bases, self.symops) + self.bases[:2], self.symops) orb2.assign_ids(1, 1, 1) bit = orb1.bit_combos[1][0] @@ -108,7 +144,8 @@ def test_remove_bit_combo_by_inds(self): orb2.remove_bit_combo(bit) self.assertTrue(all(np.array_equal(bc1, bc2) for bc1, bc2 in zip(orb1.bit_combos, orb2.bit_combos))) - self.assertRaises(RuntimeError, orb1.remove_bit_combos_by_inds, [0,1]) + self.assertRaises( + RuntimeError, orb1.remove_bit_combos_by_inds, [0, 1]) def test_eval(self): # Test cluster function evaluation with indicator basis @@ -126,7 +163,7 @@ def test_exceptions(self): self.lattice, [[0, 1], [0, 1]], self.bases, self.symops) self.assertRaises(RuntimeError, self.orbit.remove_bit_combos_by_inds, - [4]) + [6]) def test_repr(self): repr(self.orbit)