Skip to content

Commit

Permalink
Merge pull request #159 from lbluque/pytests
Browse files Browse the repository at this point in the history
(Almost) all unit tests to pytest
  • Loading branch information
lbluque authored Mar 1, 2022
2 parents 85a5ac2 + 908cbac commit 5519bed
Show file tree
Hide file tree
Showing 42 changed files with 1,879 additions and 1,690 deletions.
18 changes: 11 additions & 7 deletions smol/cofe/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class RegressionData:

module: str
estimator_name: str
parameters: dict
feature_matrix: np.ndarray
property_vector: np.ndarray
parameters: dict

@classmethod
def from_object(cls, estimator, feature_matrix, property_vector,
Expand Down Expand Up @@ -61,9 +61,9 @@ def from_object(cls, estimator, feature_matrix, property_vector,

return cls(module=estimator.__module__,
estimator_name=estimator_name,
parameters=parameters,
feature_matrix=feature_matrix,
property_vector=property_vector)
property_vector=property_vector,
parameters=parameters,)

@classmethod
def from_sklearn(cls, estimator, feature_matrix, property_vector):
Expand All @@ -81,9 +81,9 @@ def from_sklearn(cls, estimator, feature_matrix, property_vector):
"""
return cls(module=estimator.__module__,
estimator_name=estimator.__class__.__name__,
parameters=estimator.get_params(),
feature_matrix=feature_matrix,
property_vector=property_vector)
property_vector=property_vector,
parameters=estimator.get_params())


class ClusterExpansion(MSONable):
Expand Down Expand Up @@ -218,7 +218,7 @@ def predict(self, structure, normalize=False):
"""
corrs = self.cluster_subspace.corr_from_structure(
structure, normalized=normalize)
return np.dot(corrs, self.coefs)
return np.dot(self.coefs, corrs)

def prune(self, threshold=0, with_multiplicity=False):
"""Remove fit coefficients or ECI's with small values.
Expand Down Expand Up @@ -253,6 +253,10 @@ def prune(self, threshold=0, with_multiplicity=False):
self._feat_matrix = self._feat_matrix[:, ids_complement]
self._eci = None # Reset

def copy(self):
"""Return a copy of self."""
return ClusterExpansion.from_dict(self.as_dict())

def __str__(self):
"""Pretty string for printing."""
corr = np.zeros(self.cluster_subspace.num_corr_functions)
Expand All @@ -270,7 +274,7 @@ def __str__(self):
s += f' [Orbit] id: {str(0):<3}\n'
s += ' bit eci\n'
s += f' {"[X]":<10}{ecis[0]:<4.3}\n'
for orbit in self.cluster_subspace.iterorbits():
for orbit in self.cluster_subspace.orbits:
s += f' [Orbit] id: {orbit.bit_id:<3} size: ' \
f'{len(orbit.bits):<3} radius: ' \
f'{orbit.base_cluster.diameter:<4.3}\n'
Expand Down
2 changes: 1 addition & 1 deletion smol/cofe/space/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def rotate(self, angle, index1=0, index2=1):
"This basis has a non-uniform measure, rotations are not "
"implemented to handle this.\n The operation will still be "
"carried out, but it is recommended to run orthonormalize "
"again if the basis was originally so.")
"again if the basis was originally so.", UserWarning)

if len(self.site_space) == 2:
self._f_array[1] *= -1
Expand Down
4 changes: 2 additions & 2 deletions smol/cofe/space/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def from_sites(cls, sites):
"""Create a cluster from a list of pymatgen Sites."""
return cls([s.frac_coords for s in sites], sites[0].lattice)

@property
@property # TODO deprecate this
def size(self):
"""Get number of sites in the cluster."""
return len(self.sites)
Expand All @@ -78,7 +78,7 @@ def assign_ids(self, cluster_id):

def __len__(self):
"""Get size of a cluster. The number of sites."""
return self.size
return len(self.sites)

def __eq__(self, other):
"""Check equivalency of clusters considering symmetry."""
Expand Down
36 changes: 10 additions & 26 deletions smol/cofe/space/clusterspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,6 @@ def orbits(self):
return [orbit for _, orbits
in sorted(self._orbits.items()) for orbit in orbits]

def iterorbits(self):
"""Yield orbits."""
for _, orbits in sorted(self._orbits.items()):
for orbit in orbits:
yield orbit

@property
def orbits_by_size(self):
"""Get dictionary of orbits with key being the orbit size."""
Expand All @@ -290,27 +284,17 @@ def orbits_by_size(self):
@property
def orbit_multiplicities(self):
"""Get the crystallographic multiplicities for each orbit."""
mults = [1] + [orb.multiplicity for orb in self.iterorbits()]
mults = [1] + [orb.multiplicity for orb in self.orbits]
return np.array(mults)

@property
def all_bit_combos(self):
"""Return flattened all bit_combos for each correlation function.
A list of all bit_combos of length self.num_corr_functions.
This allows to obtain a bit combo by bit id (empty cluster has None)
"""
return [None] + [combos for orbit in self.orbits
for combos in orbit.bit_combos]

@property
def num_functions_per_orbit(self):
"""Get the number of correlation functions for each orbit.
The list returned is of length total number of orbits, each entry is
the total number of correlation functions assocaited with that orbit.
"""
return np.array([len(orbit) for orbit in self.iterorbits()])
return np.array([len(orbit) for orbit in self.orbits])

@property
def function_orbit_ids(self):
Expand All @@ -320,7 +304,7 @@ def function_orbit_ids(self):
in the list since they are not associated with any orbit.
"""
func_orb_ids = [0]
for orbit in self.iterorbits():
for orbit in self.orbits:
func_orb_ids += len(orbit) * [orbit.id, ]
return np.array(func_orb_ids)

Expand Down Expand Up @@ -359,12 +343,12 @@ def function_total_multiplicities(self):
@property
def basis_orthogonal(self):
"""Check if the orbit basis defined is orthogonal."""
return all(orb.basis_orthogonal for orb in self.iterorbits())
return all(orb.basis_orthogonal for orb in self.orbits)

@property
def basis_orthonormal(self):
"""Check if the orbit basis is orthonormal."""
return all(orb.basis_orthonormal for orb in self.iterorbits())
return all(orb.basis_orthonormal for orb in self.orbits)

@property
def external_terms(self):
Expand Down Expand Up @@ -731,7 +715,7 @@ def change_site_bases(self, new_basis, orthonormal=False):
orthonormal (bool):
option to orthonormalize all new site basis sets
"""
for orbit in self.iterorbits():
for orbit in self.orbits:
orbit.transform_site_bases(new_basis, orthonormal)

def remove_orbits(self, orbit_ids):
Expand Down Expand Up @@ -799,7 +783,7 @@ def remove_orbit_bit_combos(self, bit_ids):
empty_orbit_ids = []
bit_ids = np.array(bit_ids, dtype=int)

for orbit in self.iterorbits():
for orbit in self.orbits:
first_id = orbit.bit_id
last_id = orbit.bit_id + len(orbit)
to_remove = bit_ids[bit_ids >= first_id]
Expand All @@ -820,7 +804,7 @@ def remove_orbit_bit_combos(self, bit_ids):

def copy(self):
"""Deep copy of instance."""
return deepcopy(self)
return ClusterSubspace.from_dict(self.as_dict())

def structure_site_mapping(self, supercell, structure):
"""Get structure site mapping.
Expand Down Expand Up @@ -859,7 +843,7 @@ def get_sub_orbits(self, orbit_id, level=1, min_size=1):
"""
if orbit_id == 0:
return []
size = self.orbits[orbit_id - 1].base_cluster.size
size = len(self.orbits[orbit_id - 1].base_cluster)
if level is None or level < 0 or size - level - 1 < 0:
stop = 0
elif min_size > size - level:
Expand Down Expand Up @@ -1109,7 +1093,7 @@ def _gen_orbit_indices(self, scmatrix):

ts = lattice_points_in_supercell(scmatrix)
orbit_indices = []
for orbit in self.iterorbits():
for orbit in self.orbits:
prim_fcoords = np.array([c.sites for c in orbit.clusters])
fcoords = np.dot(prim_fcoords, prim_to_supercell)
# tcoords contains all the coordinates of the symmetrically
Expand Down
104 changes: 59 additions & 45 deletions smol/cofe/space/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self, sites, lattice, bits, site_bases, structure_symops):
# lazy generation of properties
self._equiv = None
self._symops = None
self._permutations = None
self._bit_combos = None
self._basis_arrs = None
self._corr_tensors = None
Expand Down Expand Up @@ -126,26 +127,23 @@ def multiplicity(self):

@property
def bit_combos(self):
"""Get list of site bit orderings.
"""Get tuple of site bit orderings.
tuple of ndarrays, each array is a set of symmetrically equivalent bit
orderings represented by row. Bit combos represent non-constant site
function orderings.
"""
if self._bit_combos is not None:
return self._bit_combos

# get all the bit symmetry operations
bit_ops = tuple(set(bit_op for _, bit_op in self.cluster_symops))
all_combos = []
for bit_combo in product(*self.bits):
if bit_combo not in chain(*all_combos):
bit_combo = np.array(bit_combo)
new_bits = list(set(
tuple(bit_combo[np.array(bit_op)]) for bit_op in bit_ops))
all_combos.append(new_bits)
self._bit_combos = tuple(
np.array(c, dtype=np.int_) for c in all_combos)
if self._bit_combos is None:
# get all the bit symmetry operations
all_combos = []
for bit_combo in product(*self.bits):
if not any(np.array_equal(bit_combo, bc)
for bc in chain(*all_combos)):
bit_combo = np.array(bit_combo, dtype=np.int_)
new_bits = np.unique(
bit_combo[self.cluster_permutations], axis=0)
all_combos.append(new_bits)
self._bit_combos = tuple(all_combos)
return self._bit_combos

@property
Expand Down Expand Up @@ -179,30 +177,23 @@ def clusters(self):
def cluster_symops(self):
"""Get symmetry operations that map a cluster to its periodic image.
Each element is a tuple of (pymatgen.core.operations.Symop, mapping)
where mapping is a tuple such that
Symop.operate(sites) = sites[mapping]
(after translation back to unit cell)
Each element is a pymatgen.core.operations.Symop.
"""
if self._symops:
return self._symops

self._symops = []
for symop in self.structure_symops:
new_sites = symop.operate_multi(self.base_cluster.sites)
cluster = Cluster(new_sites, self.base_cluster.lattice)
if cluster == self.base_cluster:
recenter = np.round(
self.base_cluster.centroid - cluster.centroid)
c_sites = cluster.sites + recenter
mapping = tuple(coord_list_mapping(
self.base_cluster.sites, c_sites, atol=SITE_TOL))
self._symops.append((symop, mapping))
if self._symops is None:
self._gen_cluster_symops()
return self._symops

if len(self._symops) * self.multiplicity != len(self.structure_symops):
raise SymmetryError(SYMMETRY_ERROR_MESSAGE)
@property
def cluster_permutations(self):
"""Get the symmetrical site permutations that map a cluster to itself.
return self._symops
A permutation is a mapping such that for a give symop in cluster_symops
Symop.operate(sites) = sites[mapping] (after translation back to unit
cell)
"""
if self._permutations is None:
self._gen_cluster_symops()
return self._permutations

@property
def basis_arrays(self):
Expand Down Expand Up @@ -313,8 +304,8 @@ def remove_bit_combo(self, bits): # seems like this is no longer used?
def remove_bit_combos_by_inds(self, inds):
"""Remove bit combos by their indices in the bit_combo list."""
if max(inds) > len(self.bit_combos) - 1:
raise RuntimeError(
f"Some indices {inds} out of ranges for total "
raise ValueError(
f"Some indices {inds} out of range for total "
f"{len(self._bit_combos)} bit combos")

self._bit_combos = tuple(
Expand Down Expand Up @@ -355,7 +346,7 @@ def is_sub_orbit(self, orbit):
Returns:
bool: True if the clusters of given orbit are subclusters.
"""
if self.base_cluster.size <= orbit.base_cluster.size:
if len(self.base_cluster) <= len(orbit.base_cluster):
return False
elif not np.all(sp in self.site_spaces for sp in orbit.site_spaces):
return False
Expand All @@ -365,7 +356,7 @@ def is_sub_orbit(self, orbit):
self.base_cluster.lattice)
in orbit.clusters
for inds in combinations(
range(self.base_cluster.size), orbit.base_cluster.size))
range(len(self.base_cluster)), len(orbit.base_cluster)))

return match

Expand All @@ -384,8 +375,10 @@ def sub_orbit_mappings(self, orbit):
self.base_cluster.sites[indices] = orbit.base_cluster.sites
"""
indsets = np.array(list(combinations(
(i for i, space in enumerate(self.site_spaces)
if space in orbit.site_spaces), len(orbit.site_spaces))))
(
i for i, space in enumerate(self.site_spaces)
if space in orbit.site_spaces), len(orbit.site_spaces)
)), dtype=int)

mappings = []
for cluster in self.clusters:
Expand All @@ -395,16 +388,37 @@ def sub_orbit_mappings(self, orbit):
recenter = np.round(centroid - orbit.base_cluster.centroid)
c_sites = orbit.base_cluster.sites + recenter
if is_coord_subset(c_sites, cluster.sites):
mappings.append(
coord_list_mapping(
c_sites, cluster.sites, atol=SITE_TOL))
mappings.append(coord_list_mapping(
c_sites, cluster.sites, atol=SITE_TOL))

if len(mappings) == 0 and self.is_sub_orbit(orbit):
raise RuntimeError(
"The given orbit is a suborbit, but no site mappings were "
"found!\n Something is very wrong here!")
return np.unique(mappings, axis=0)

def _gen_cluster_symops(self):
"""Generate the cluster SymOps and decoration permutations."""
symops = []
permutations = []
for symop in self.structure_symops:
new_sites = symop.operate_multi(self.base_cluster.sites)
cluster = Cluster(new_sites, self.base_cluster.lattice)
if cluster == self.base_cluster:
recenter = np.round(
self.base_cluster.centroid - cluster.centroid)
c_sites = cluster.sites + recenter
mapping = coord_list_mapping(
self.base_cluster.sites, c_sites, atol=SITE_TOL)
symops.append(symop)
permutations.append(mapping)

self._permutations = np.unique(permutations, axis=0)
self._symops = tuple(symops)

if len(self._symops) * self.multiplicity != len(self.structure_symops):
raise SymmetryError(SYMMETRY_ERROR_MESSAGE)

def reset_bases(self):
"""Reset cached basis function array and correlation tensors."""
self._basis_arrs = None
Expand Down
Loading

0 comments on commit 5519bed

Please sign in to comment.