Skip to content

Commit

Permalink
update: control axes for pbc aug
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Sep 6, 2024
1 parent f656e46 commit 4db6495
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/py/mat3ra/made/tools/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import wraps
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple

import numpy as np
from mat3ra.made.material import Material
Expand Down Expand Up @@ -185,13 +185,16 @@ def filter_and_translate(coordinates: np.ndarray, elements: np.ndarray, axis: in
return translated_coordinates, filtered_elements


def augment_material_with_periodic_images(material: Material, cutoff: float = 0.1):
def augment_material_with_periodic_images(
material: Material, cutoff: float = 0.1, directions: Tuple[bool, bool, bool] = (True, True, True)
):
"""
Augment the material's dataset by adding atoms from periodic images near boundaries.
Args:
material (Material): The material to augment.
cutoff (float): The cutoff value for filtering atoms near boundaries.
directions (List[bool]): The directions to augment (flags for (x, y, z)).
Returns:
Tuple[Material, int]: The augmented material and the original count of atoms.
Expand All @@ -203,10 +206,13 @@ def augment_material_with_periodic_images(material: Material, cutoff: float = 0.
new_basis = augmented_material.basis.copy()

for axis in range(3):
for direction in [-1, 1]:
translated_coords, translated_elems = filter_and_translate(coordinates, elements, axis, cutoff, direction)
for coord, elem in zip(translated_coords, translated_elems):
new_basis.add_atom(elem, coord)
if directions[axis]:
for direction in [-1, 1]:
translated_coords, translated_elems = filter_and_translate(
coordinates, elements, axis, cutoff, direction
)
for coord, elem in zip(translated_coords, translated_elems):
new_basis.add_atom(elem, coord)

augmented_material.basis = new_basis
return augmented_material, last_id

0 comments on commit 4db6495

Please sign in to comment.