Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 2, 2024
1 parent f802400 commit e2aead1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 80 deletions.
124 changes: 62 additions & 62 deletions znframe/bonds/__init__.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,62 @@
import ase
import networkx as nx
import numpy as np
from ase.neighborlist import natural_cutoffs
from networkx.exception import NetworkXError
from pydantic import BaseModel, Field


class ASEComputeBonds(BaseModel):
single_bond_multiplier: float = Field(1.2, le=2, ge=0)
double_bond_multiplier: float = Field(0.9, le=1, ge=0)
triple_bond_multiplier: float = Field(0.0, le=1, ge=0)

def build_graph(self, atoms: ase.Atoms):
cutoffs = [
self.single_bond_multiplier,
self.double_bond_multiplier,
self.triple_bond_multiplier,
]
atoms_copy = atoms.copy()
connectivity_matrix = np.zeros((len(atoms_copy), len(atoms_copy)), dtype=int)
atoms_copy.pbc = False
distance_matrix = atoms_copy.get_all_distances(mic=False)
np.fill_diagonal(distance_matrix, np.inf)
for cutoff in cutoffs:
cutoffs = np.array(natural_cutoffs(atoms_copy, mult=cutoff))
cutoffs = cutoffs[:, None] + cutoffs[None, :]
connectivity_matrix[distance_matrix <= cutoffs] += 1
G = nx.from_numpy_array(connectivity_matrix)
return G

def update_graph_using_modifications(self, atoms: ase.Atoms):
modifications = atoms.info.get("modifications", {})
graph = atoms.connectivity
for key in modifications:
atom_1, atom_2 = key
weight = modifications[key]
if weight == 0:
self.remove_edge(graph, atom_1, atom_2)
else:
graph.add_edge(atom_1, atom_2, weight=weight)

@staticmethod
def remove_edge(graph, atom_1, atom_2):
try:
graph.remove_edge(atom_1, atom_2)
except NetworkXError:
pass

def get_bonds(self, graph):
bonds = []
for edge in graph.edges:
bonds.append((edge[0], edge[1], graph.edges[edge]["weight"]))
return bonds

def update_bond_order(self, atoms: ase.Atoms, particles: list[int], order: int):
if len(particles) != 2:
raise ValueError("Exactly two particles must be selected")
modifications = atoms.info.get("modifications", {})
sorted_particles = tuple(sorted(particles))
modifications[sorted_particles] = order
atoms.info["modifications"] = modifications
import ase
import networkx as nx
import numpy as np
from ase.neighborlist import natural_cutoffs
from networkx.exception import NetworkXError
from pydantic import BaseModel, Field


class ASEComputeBonds(BaseModel):
single_bond_multiplier: float = Field(1.2, le=2, ge=0)
double_bond_multiplier: float = Field(0.9, le=1, ge=0)
triple_bond_multiplier: float = Field(0.0, le=1, ge=0)

def build_graph(self, atoms: ase.Atoms):
cutoffs = [
self.single_bond_multiplier,
self.double_bond_multiplier,
self.triple_bond_multiplier,
]
atoms_copy = atoms.copy()
connectivity_matrix = np.zeros((len(atoms_copy), len(atoms_copy)), dtype=int)
atoms_copy.pbc = False
distance_matrix = atoms_copy.get_all_distances(mic=False)
np.fill_diagonal(distance_matrix, np.inf)
for cutoff in cutoffs:
cutoffs = np.array(natural_cutoffs(atoms_copy, mult=cutoff))
cutoffs = cutoffs[:, None] + cutoffs[None, :]
connectivity_matrix[distance_matrix <= cutoffs] += 1
G = nx.from_numpy_array(connectivity_matrix)
return G

def update_graph_using_modifications(self, atoms: ase.Atoms):
modifications = atoms.info.get("modifications", {})
graph = atoms.connectivity
for key in modifications:
atom_1, atom_2 = key
weight = modifications[key]
if weight == 0:
self.remove_edge(graph, atom_1, atom_2)
else:
graph.add_edge(atom_1, atom_2, weight=weight)

@staticmethod
def remove_edge(graph, atom_1, atom_2):
try:
graph.remove_edge(atom_1, atom_2)
except NetworkXError:
pass

def get_bonds(self, graph):
bonds = []
for edge in graph.edges:
bonds.append((edge[0], edge[1], graph.edges[edge]["weight"]))
return bonds

def update_bond_order(self, atoms: ase.Atoms, particles: list[int], order: int):
if len(particles) != 2:
raise ValueError("Exactly two particles must be selected")
modifications = atoms.info.get("modifications", {})
sorted_particles = tuple(sorted(particles))
modifications[sorted_particles] = order
atoms.info["modifications"] = modifications
33 changes: 15 additions & 18 deletions znframe/frame.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from attrs import define, field, cmp_using, Factory
from attrs import define, field, cmp_using
import attrs
import numpy as np
import ase.cell
from ase.data.colors import jmol_colors
from copy import deepcopy
import json
import networkx as nx

from znframe.bonds import ASEComputeBonds


def _cell_to_array(cell: np.ndarray | ase.cell.Cell) -> np.ndarray:
if isinstance(cell, np.ndarray):
return cell
Expand Down Expand Up @@ -39,33 +39,32 @@ def _ndarray_to_list(array: dict | np.ndarray) -> dict | list:

@define
class Frame:
numbers: np.ndarray = field(
converter=_list_to_array, eq=cmp_using(np.array_equal)
)
numbers: np.ndarray = field(converter=_list_to_array, eq=cmp_using(np.array_equal))
positions: np.ndarray = field(
converter=_list_to_array, eq=cmp_using(np.array_equal)
)
)

connectivity: np.ndarray = field(
converter=_list_to_array, eq=cmp_using(np.array_equal), default=None
)
)

arrays: dict[str, np.ndarray] = field(
converter=_list_to_array, eq=False, factory=dict
)
)
info: dict[str, float | int | np.ndarray] = field(
converter=_list_to_array, eq=False, factory=dict
)
)

pbc: np.ndarray = field(
converter=_list_to_array, eq=cmp_using(np.array_equal), default=np.array([True, True, True])
)
converter=_list_to_array,
eq=cmp_using(np.array_equal),
default=np.array([True, True, True]),
)
cell: np.ndarray = field(
converter=_cell_to_array, eq=cmp_using(np.array_equal), default=np.zeros(3)
)
)

def __attrs_post_init__(self):

if self.connectivity is None:
ase_bond_calculator = ASEComputeBonds()
self.connectivity = ase_bond_calculator.build_graph(self.to_atoms())
Expand All @@ -74,11 +73,9 @@ def __attrs_post_init__(self):
if "colors" not in self.arrays:
self.arrays["colors"] = [
rgb2hex(jmol_colors[number]) for number in self.numbers
]
if "radii" not in self.arrays:
self.arrays["radii"] = [
get_radius(number) for number in self.numbers
]
if "radii" not in self.arrays:
self.arrays["radii"] = [get_radius(number) for number in self.numbers]

@classmethod
def from_atoms(cls, atoms: ase.Atoms):
Expand Down Expand Up @@ -128,6 +125,6 @@ def rgb2hex(value):
r, g, b = np.array(value * 255, dtype=int)
return "#%02x%02x%02x" % (r, g, b)


def get_radius(value):
return (0.25 * (2 - np.exp(-0.2 * value)),)

0 comments on commit e2aead1

Please sign in to comment.