Skip to content

Commit

Permalink
update: remove tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Aug 12, 2024
1 parent 5f2dcf4 commit 8ce0856
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 78 deletions.
8 changes: 4 additions & 4 deletions src/py/mat3ra/made/tools/build/defect/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from ....utils import get_center_of_coordinates
from ...utils import transform_coordinate_to_supercell
from ...utils.coordinate import CoordinateConditionBuilder
from ...utils import coordinate as CoordinateCondition
from ..utils import merge_materials
from ..slab import SlabConfiguration, create_slab, Termination
from ..supercell import create_supercell
Expand Down Expand Up @@ -437,7 +437,7 @@ def condition(coordinate: List[float]):
return self.merge_slab_and_defect(island_material, new_material)

def _generate(self, configuration: _ConfigurationType) -> List[_GeneratedItemType]:
condition_callable, _ = configuration.condition
condition_callable = configuration.condition.condition
return [
self.create_island(
material=configuration.crystal,
Expand Down Expand Up @@ -593,10 +593,10 @@ def create_terrace(
)

normalized_direction_vector = self._calculate_cut_direction_vector(material, cut_direction)
condition, _ = CoordinateConditionBuilder.plane(
condition = CoordinateCondition.PlaneCoordinateCondition(
plane_normal=normalized_direction_vector,
plane_point_coordinate=pivot_coordinate,
)
).condition
atoms_within_terrace = filter_by_condition_on_coordinates(
material=material_with_additional_layers,
condition=condition,
Expand Down
27 changes: 20 additions & 7 deletions src/py/mat3ra/made/tools/build/defect/configuration.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from typing import Optional, List, Any, Callable, Dict, Tuple, Union
from typing import Optional, List, Union
from pydantic import BaseModel

from mat3ra.code.entity import InMemoryEntity
from mat3ra.made.material import Material

from ...analyze import get_closest_site_id_from_coordinate, get_atomic_coordinates_extremum
from ...utils.coordinate import CoordinateConditionBuilder
from ...utils.coordinate import (
CylinderCoordinateCondition,
SphereCoordinateCondition,
BoxCoordinateCondition,
TriangularPrismCoordinateCondition,
PlaneCoordinateCondition,
)
from .enums import PointDefectTypeEnum, SlabDefectTypeEnum, AtomPlacementMethodEnum, ComplexDefectTypeEnum


class BaseDefectConfiguration(BaseModel):
# TODO: fix arbitrary_types_allowed error and set Material class type
crystal: Any = None
crystal: Material = None

class Config:
arbitrary_types_allowed = True

@property
def _json(self):
Expand Down Expand Up @@ -169,16 +177,21 @@ class IslandSlabDefectConfiguration(SlabDefectConfiguration):
"""

defect_type: SlabDefectTypeEnum = SlabDefectTypeEnum.ISLAND
condition: Optional[Tuple[Callable[[List[float]], bool], Dict]] = CoordinateConditionBuilder().cylinder()
condition: Union[
CylinderCoordinateCondition,
SphereCoordinateCondition,
BoxCoordinateCondition,
TriangularPrismCoordinateCondition,
PlaneCoordinateCondition,
] = CylinderCoordinateCondition()

@property
def _json(self):
_, condition_json = self.condition
return {
**super()._json,
"type": self.get_cls_name(),
"defect_type": self.defect_type.name,
"condition": condition_json,
"condition": self.condition.to_json(),
}


Expand Down
120 changes: 55 additions & 65 deletions src/py/mat3ra/made/tools/utils/coordinate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Place all functions acting on coordinates
from typing import Callable, Dict, List, Tuple
from typing import Dict, List

import numpy as np
from pydantic import BaseModel, Field


def is_coordinate_in_cylinder(
Expand Down Expand Up @@ -154,72 +155,61 @@ def is_coordinate_behind_plane(
return np.dot(np_plane_normal, np_coordinate - np_plane_point) < 0


class CoordinateConditionBuilder:
@staticmethod
def create_condition(condition_type: str, evaluation_func: Callable, **kwargs) -> Tuple[Callable, Dict]:
condition_json = {"type": condition_type, **kwargs}
return lambda coordinate: evaluation_func(coordinate, **kwargs), condition_json

@staticmethod
def cylinder(center_position=None, radius: float = 0.25, min_z: float = 0, max_z: float = 1):
if center_position is None:
center_position = [0.5, 0.5]
return CoordinateConditionBuilder.create_condition(
condition_type="cylinder",
evaluation_func=is_coordinate_in_cylinder,
center_position=center_position,
radius=radius,
min_z=min_z,
max_z=max_z,
)
class CoordinateCondition(BaseModel):
def condition(self, coordinate: List[float]) -> bool:
raise NotImplementedError

@staticmethod
def sphere(center_position=None, radius: float = 0.25):
if center_position is None:
center_position = [0.5, 0.5, 0.5]
return CoordinateConditionBuilder.create_condition(
condition_type="sphere",
evaluation_func=is_coordinate_in_sphere,
center_position=center_position,
radius=radius,
)
def to_json(self) -> Dict:
return self.dict()

@staticmethod
def triangular_prism(
position_on_surface_1: List[float] = [0, 0],
position_on_surface_2: List[float] = [1, 0],
position_on_surface_3: List[float] = [0, 1],
min_z: float = 0,
max_z: float = 1,
):
return CoordinateConditionBuilder.create_condition(
condition_type="prism",
evaluation_func=is_coordinate_in_triangular_prism,
coordinate_1=position_on_surface_1,
coordinate_2=position_on_surface_2,
coordinate_3=position_on_surface_3,
min_z=min_z,
max_z=max_z,
)

@staticmethod
def box(min_coordinate=None, max_coordinate=None):
if max_coordinate is None:
max_coordinate = [1, 1, 1]
if min_coordinate is None:
min_coordinate = [0, 0, 0]
return CoordinateConditionBuilder.create_condition(
condition_type="box",
evaluation_func=is_coordinate_in_box,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
)
class CylinderCoordinateCondition(CoordinateCondition):
center_position: List[float] = Field(default_factory=lambda: [0.5, 0.5])
radius: float = 0.25
min_z: float = 0
max_z: float = 1

def condition(self, coordinate: List[float]) -> bool:
return is_coordinate_in_cylinder(coordinate, self.center_position, self.radius, self.min_z, self.max_z)


class SphereCoordinateCondition(CoordinateCondition):
center_position: List[float] = Field(default_factory=lambda: [0.5, 0.5])
radius: float = 0.25

def condition(self, coordinate: List[float]) -> bool:
return is_coordinate_in_sphere(coordinate, self.center_position, self.radius)

@staticmethod
def plane(plane_normal: List[float], plane_point_coordinate: List[float]):
return CoordinateConditionBuilder.create_condition(
condition_type="plane",
evaluation_func=is_coordinate_behind_plane,
plane_normal=plane_normal,
plane_point_coordinate=plane_point_coordinate,

class BoxCoordinateCondition(CoordinateCondition):
min_coordinate: List[float] = Field(default_factory=lambda: [0, 0, 0])
max_coordinate: List[float] = Field(default_factory=lambda: [1, 1, 1])

def condition(self, coordinate: List[float]) -> bool:
return is_coordinate_in_box(coordinate, self.min_coordinate, self.max_coordinate)


class TriangularPrismCoordinateCondition(CoordinateCondition):
position_on_surface_1: List[float] = [0, 0]
position_on_surface_2: List[float] = [1, 0]
position_on_surface_3: List[float] = [0, 1]
min_z: float = 0
max_z: float = 1

def condition(self, coordinate: List[float]) -> bool:
return is_coordinate_in_triangular_prism(
coordinate,
self.position_on_surface_1,
self.position_on_surface_2,
self.position_on_surface_3,
self.min_z,
self.max_z,
)


class PlaneCoordinateCondition(CoordinateCondition):
plane_normal: List[float]
plane_point_coordinate: List[float]

def condition(self, coordinate: List[float]) -> bool:
return is_coordinate_behind_plane(coordinate, self.plane_normal, self.plane_point_coordinate)
6 changes: 4 additions & 2 deletions tests/py/unit/test_tools_build_defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
PointDefectPairConfiguration,
TerraceSlabDefectConfiguration,
)
from mat3ra.made.tools.utils.coordinate import CoordinateConditionBuilder
from mat3ra.made.tools.utils import coordinate as CoordinateCondition
from mat3ra.utils import assertion as assertion_utils

from .fixtures import SLAB_001, SLAB_111
Expand Down Expand Up @@ -114,7 +114,9 @@ def test_create_crystal_site_adatom():


def test_create_island():
condition = CoordinateConditionBuilder.cylinder(center_position=[0.625, 0.5], radius=0.25, min_z=0, max_z=1)
condition = CoordinateCondition.CylinderCoordinateCondition(
center_position=[0.625, 0.5], radius=0.25, min_z=0, max_z=1
)
island_config = IslandSlabDefectConfiguration(
crystal=SLAB_111,
defect_type="island",
Expand Down

0 comments on commit 8ce0856

Please sign in to comment.