Skip to content

Commit

Permalink
update: address pr comments + config class
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Dec 8, 2024
1 parent 6565a48 commit a50c2f8
Showing 1 changed file with 86 additions and 131 deletions.
217 changes: 86 additions & 131 deletions src/py/mat3ra/made/tools/analyze/coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,80 @@ def are_bond_directions_similar(directions1: np.ndarray, directions2: np.ndarray

return True

@staticmethod
def find_missing_directions(
vectors: List[np.ndarray],
element: str,
templates: Dict[str, List[np.ndarray]],
angle_tolerance: float = 0.1,
max_bonds_to_passivate: int = 1,
) -> List[List[float]]:
"""
Reconstruct missing bonds for a single element.
Args:
vectors (List[np.ndarray]): List of bond vectors for the atom.
element (str): Chemical element of the atom.
templates (Dict[str, List[np.ndarray]]): Dictionary of bond templates for each element.
angle_tolerance (float): Tolerance for comparing angles between bond vectors.
max_bonds_to_passivate (int): Maximum number of bonds to passivate for the atom.
Returns:
List[List[float]]: List of reconstructed bond vectors.
"""
if element not in templates:
return []

existing_vectors = np.array(vectors) if vectors else np.empty((0, 3))
max_coordination_number = len(templates[element][0])

if len(existing_vectors) >= max_coordination_number:
return []

best_missing = None
best_match_count = -1

for template in templates[element]:
if existing_vectors.size == 0:
match_count = 0
else:
dot_matrix = np.dot(template, existing_vectors.T)
cosine_matrix = dot_matrix / (
np.linalg.norm(template, axis=1)[:, None] * np.linalg.norm(existing_vectors, axis=1)
)
angles_matrix = np.arccos(np.clip(cosine_matrix, -1.0, 1.0))

matches = np.any(angles_matrix < angle_tolerance, axis=1)
match_count = np.sum(matches)

missing = template[~matches] if existing_vectors.size != 0 else template

if match_count > best_match_count:
best_match_count = match_count
best_missing = missing

if best_missing is not None:
num_bonds_to_add = min(
len(best_missing),
max_bonds_to_passivate,
max_coordination_number - len(existing_vectors),
)
return best_missing[:num_bonds_to_add].tolist()

return []


class CrystalSite(BaseModel):
# element: str
# coordinate: List[float]
nearest_neighbor_vectors: []
nearest_neighbor_vectors: List[np.ndarray] = []
# coordination_number: int = 0
# see https://www.cryst.ehu.es/cgi-bin/cryst/programs/nph-wp-list for an example
wyckoff_letter: Optional[str] = None

class Config:
arbitrary_types_allowed = True

@property
def coordination_number(self):
return len(self.nearest_neighbor_vectors)
Expand All @@ -75,47 +140,18 @@ class CrystalSiteList(ArrayWithIds):


class MaterialWithCrystalSites(Material):
crystal_sites: CrystalSiteList = CrystalSiteList()
crystal_sites: CrystalSiteList = CrystalSiteList(values=[])

def __init__(self, **data):
super().__init__(**data)
self.nearest_neighbor_vectors = self.get_nearest_neighbors_vectors()
self.nearest_neighbor_vectors = self.get_neighbors_vectors_for_all_sites(cutoff=3.0)
self.crystal_sites = CrystalSiteList(nearest_neighbor_vectors=self.nearest_neighbor_vectors)

@property
def coordinates_as_kdtree(self):
return cKDTree(self.basis.coordinates.values)

@decorator_handle_periodic_boundary_conditions(0.25)
def get_nearest_neighbors_vectors(
self,
cutoff: float = 3.0,
nearest_only: bool = True,
) -> ArrayWithIds:
"""
Calculate the vectors to the nearest neighbors for each atom in the material.
Args:
material (Material): Material object to calculate coordination numbers for.
indices (List[int]): List of atom indices to calculate coordination numbers for.
cutoff (float): The maximum cutoff radius for identifying neighbors.
nearest_only (bool): If True, only consider the first shell of neighbors.
Returns:
ArrayWithIds: Array of vectors to the nearest neighbors for each atom.
"""
new_material = self.clone()
new_material.to_cartesian()
coordinates = np.array(new_material.basis.coordinates.values)
kd_tree = cKDTree(coordinates)

nearest_neighbors = self.get_nearest_neighbors_for_all_sites()
vectors = [
coordinates[neighbor.id] - coordinates[neighbor.id]
for neighbor in nearest_neighbors.to_array_of_values_with_ids()
]
return ArrayWithIds(values=vectors, ids=nearest_neighbors.ids)

@decorator_handle_periodic_boundary_conditions(cutoff=0.25)
def get_neighbors_vectors_for_site(
self, site_index: int, cutoff: float = 3.0, max_number_of_neighbors: Optional[int] = None
):
Expand All @@ -131,7 +167,6 @@ def get_neighbors_vectors_for_all_sites(self, cutoff: float = 3.0, max_number_of
nearest_neighbors.add_item(vectors, site_index)
return nearest_neighbors

@decorator_handle_periodic_boundary_conditions(0.25)
def get_neighbors_for_site(
self,
site_index: int,
Expand Down Expand Up @@ -198,6 +233,17 @@ def get_coordination_numbers(self, cutoff: float = 3.0):
coordination_numbers = ArrayWithIds(values=nearest_neighbors)
return coordination_numbers

def find_missing_bonds_for_all_sites(self, templates: Dict[str, List[np.ndarray]]) -> Dict[int, List[List[float]]]:

missing_bonds = {}
for idx, (vectors) in enumerate(self.nearest_neighbor_vectors):
reconstructed_bonds = BondDirections.find_missing_directions(
vectors, self.basis.elements.values[idx], templates
)
if reconstructed_bonds:
missing_bonds[idx] = reconstructed_bonds
return missing_bonds


def get_voronoi_nearest_neighbors_atom_indices(
material: Material,
Expand Down Expand Up @@ -334,100 +380,6 @@ def find_bond_directions_for_element(
return unique_templates


def reconstruct_missing_bonds_for_element(
vectors: List[np.ndarray],
element: str,
templates: Dict[str, List[np.ndarray]],
angle_tolerance: float = 0.1,
max_bonds_to_passivate: int = 1,
) -> List[List[float]]:
"""
Reconstruct missing bonds for a single element.
Args:
vectors (List[np.ndarray]): List of bond vectors for the atom.
element (str): Chemical element of the atom.
templates (Dict[str, List[np.ndarray]]): Dictionary of bond templates for each element.
angle_tolerance (float): Tolerance for comparing angles between bond vectors.
max_bonds_to_passivate (int): Maximum number of bonds to passivate for the atom.
Returns:
List[List[float]]: List of reconstructed bond vectors.
"""
if element not in templates:
return []

existing_vectors = np.array(vectors) if vectors else np.empty((0, 3))
max_coordination_number = len(templates[element][0])

if len(existing_vectors) >= max_coordination_number:
return []

best_missing = None
best_match_count = -1

for template in templates[element]:
if existing_vectors.size == 0:
match_count = 0
else:
dot_matrix = np.dot(template, existing_vectors.T)
cosine_matrix = dot_matrix / (
np.linalg.norm(template, axis=1)[:, None] * np.linalg.norm(existing_vectors, axis=1)
)
angles_matrix = np.arccos(np.clip(cosine_matrix, -1.0, 1.0))

matches = np.any(angles_matrix < angle_tolerance, axis=1)
match_count = np.sum(matches)

missing = template[~matches] if existing_vectors.size != 0 else template

if match_count > best_match_count:
best_match_count = match_count
best_missing = missing

if best_missing is not None:
num_bonds_to_add = min(
len(best_missing),
max_bonds_to_passivate,
max_coordination_number - len(existing_vectors),
)
return best_missing[:num_bonds_to_add].tolist()

return []


def reconstruct_missing_bonds(
nearest_neighbor_vectors: List[List[np.ndarray]],
chemical_elements: List[str],
templates: Dict[str, List[np.ndarray]],
angle_tolerance: float = 0.1,
max_bonds_to_passivate: int = 1,
) -> Dict[int, List[List[float]]]:
"""
Reconstruct missing bonds for all undercoordinated atoms.
Args:
nearest_neighbor_vectors (List[List[np.ndarray]]): List of bond vectors for each atom.
chemical_elements (List[str]): List of chemical elements for each atom.
templates (Dict[str, List[np.ndarray]]): Dictionary of bond templates for each element.
angle_tolerance (float): Tolerance for comparing angles between bond vectors.
max_bonds_to_passivate (int): Maximum number of bonds to passivate for each undercoordinated atom.
Returns:
Dict[int, List[List[float]]]: Dictionary mapping atom indices to reconstructed bond vectors.
"""
missing_bonds = {}

for idx, (vectors, element) in enumerate(zip(nearest_neighbor_vectors, chemical_elements)):
reconstructed_bonds = reconstruct_missing_bonds_for_element(
vectors, element, templates, angle_tolerance, max_bonds_to_passivate
)
if reconstructed_bonds:
missing_bonds[idx] = reconstructed_bonds

return missing_bonds


####################################################################################################
# Radial Distribution Function (RDF) Analysis
####################################################################################################
Expand All @@ -437,6 +389,9 @@ class RadicalDistributionFunction(BaseModel):
rdf: np.ndarray
bin_centers: np.ndarray

class Config:
arbitrary_types_allowed = True

@classmethod
def from_material(cls, material: Material, cutoff: float = 10.0, bin_size: float = 0.1):
analyzer = BaseMaterialAnalyzer(material)
Expand Down Expand Up @@ -491,9 +446,9 @@ def first_peak_width(self):
def first_peak_distance(self):
return self.bin_centers[self.first_peak_index]

def is_within_first_peak(self, distance: float):
def is_within_first_peak(self, distance: float, tolerance: float = 0.1) -> bool:
return (
self.first_peak_distance - 0.5 * self.first_peak_width
self.first_peak_distance - 0.5 * self.first_peak_width - tolerance
< distance
< self.first_peak_distance + 0.5 * self.first_peak_width
< self.first_peak_distance + 0.5 * self.first_peak_width + tolerance
)

0 comments on commit a50c2f8

Please sign in to comment.