Skip to content

Commit

Permalink
update: use OOP and simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Aug 9, 2024
1 parent a7c1993 commit 85f65af
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 93 deletions.
8 changes: 4 additions & 4 deletions src/py/mat3ra/made/tools/build/perturbation/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from mat3ra.made.material import Material
from pydantic import BaseModel

from ...utils.functions import SineWavePerturbationFunctionHolder, GeneralPerturbationFunctionHolder
from ...utils.functions import SineWavePerturbationFunctionHolder, PerturbationFunctionHolder


class PerturbationConfiguration(BaseModel, InMemoryEntity):
material: Material
perturbation_function_holder: Union[
SineWavePerturbationFunctionHolder, GeneralPerturbationFunctionHolder
] = SineWavePerturbationFunctionHolder()
perturbation_function_holder: Union[SineWavePerturbationFunctionHolder, PerturbationFunctionHolder] = (
SineWavePerturbationFunctionHolder()
)
use_cartesian_coordinates: bool = True

class Config:
Expand Down
172 changes: 83 additions & 89 deletions src/py/mat3ra/made/tools/utils/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Literal
from typing import Any, Callable, List, Optional

import numpy as np
import sympy as sp
Expand All @@ -17,7 +17,7 @@ def apply_function(self, coordinate: List[float]) -> float:
"""
raise NotImplementedError

def apply_derivative(self, coordinate: List[float]) -> float:
def apply_derivative(self, coordinate: List[float], axis: str) -> float:
"""
Get the derivative of the function at the given coordinate
"""
Expand All @@ -36,81 +36,11 @@ def get_json(self) -> dict:
raise NotImplementedError


class PerturbationFunctionHolder(FunctionHolder):
def get_arc_length_equation(self, w_prime: float, w: float) -> float:
"""
Get the arc length equation for the perturbation function.
"""
arc_length = quad(
lambda t: np.sqrt(1 + (self.apply_derivative([t]) ** 2)),
a=0,
b=w_prime,
)[0]
return arc_length - w

def transform_coordinates(self, coordinate: List[float]) -> List[float]:
"""
Transform coordinates to preserve the distance between points on a sine wave when perturbation is applied.
Achieved by calculating the integral of the length between [0,0,0] and given coordinate.
Returns:
Callable[[List[float]], List[float]]: The coordinates transformation function.
"""
raise NotImplementedError

def apply_perturbation(self, coordinate: List[float]) -> List[float]:
"""
Apply the perturbation to the given coordinate.
"""
raise NotImplementedError


class SineWavePerturbationFunctionHolder(PerturbationFunctionHolder):
amplitude: float = 0.05
wavelength: float = 1
phase: float = 0
axis: Literal["x", "y"] = "x"

def apply_function(self, coordinate: List[float]) -> float:
w = coordinate[AXIS_TO_INDEX_MAP[self.axis]]
return self.amplitude * np.sin(2 * np.pi * w / self.wavelength + self.phase)

def apply_derivative(self, coordinate: List[float]) -> float:
w = coordinate[AXIS_TO_INDEX_MAP[self.axis]]
return self.amplitude * 2 * np.pi / self.wavelength * np.cos(2 * np.pi * w / self.wavelength + self.phase)

def apply_perturbation(self, coordinate: List[float]) -> List[float]:
return [coordinate[0], coordinate[1], coordinate[2] + self.apply_function(coordinate)]

def transform_coordinates(self, coordinate: List[float]) -> List[float]:
index = AXIS_TO_INDEX_MAP[self.axis]

w = coordinate[index]
# Find x' such that the integral from 0 to x' equals x
result = root_scalar(
self.get_arc_length_equation,
args=w,
bracket=[0, EQUATION_RANGE_COEFFICIENT * w],
method="brentq",
)
coordinate[index] = result.root
return coordinate

def get_json(self) -> dict:
return {
"type": self.__class__.__name__,
"amplitude": self.amplitude,
"wavelength": self.wavelength,
"phase": self.phase,
"axis": self.axis,
}


def default_function(coordinate: List[float]) -> float:
return 0


class GeneralPerturbationFunctionHolder(PerturbationFunctionHolder):
class PerturbationFunctionHolder(FunctionHolder):
variables: List[str] = ["x"]
symbols: List[sp.Symbol] = [sp.Symbol(var) for var in variables]
function: sp.Expr = sp.Symbol("f")
Expand All @@ -120,11 +50,14 @@ class GeneralPerturbationFunctionHolder(PerturbationFunctionHolder):
class Config:
arbitrary_types_allowed = True

def __init__(self, function: Callable, variables: List[str], **data: Any):
def __init__(self, function: Optional[Callable] = None, variables: Optional[List[str]] = None, **data: Any):
"""
Initializes with a function involving multiple variables.
"""

if function is None:
function = default_function
if variables is None:
variables = ["x"]
super().__init__(**data)
self.variables = variables
self.symbols = sp.symbols(variables)
Expand All @@ -135,12 +68,50 @@ def __init__(self, function: Callable, variables: List[str], **data: Any):
}

def apply_function(self, coordinate: List[float]) -> float:
values = [coordinate[{"x": 0, "y": 1, "z": 2}[var]] for var in self.variables]
values = [coordinate[AXIS_TO_INDEX_MAP[var]] for var in self.variables]
return self.function_numeric(*values)

def apply_derivative(self, coordinate: List[float], axis: str) -> float: # type: ignore
values = [coordinate[{"x": 0, "y": 1, "z": 2}[var]] for var in self.variables]
return self.derivatives_numeric[axis](*values)
def apply_derivative(self, coordinate: List[float], axis: str) -> float:
if axis in self.variables:
values = [coordinate[AXIS_TO_INDEX_MAP[var]] for var in self.variables]
return self.derivatives_numeric[axis](*values)
else:
return 0

def get_arc_length_equation(self, w_prime: float, coordinate: List[float], axis: str) -> float:
"""
Calculate arc length considering a change along one specific axis.
"""
index = AXIS_TO_INDEX_MAP[axis]
a, b = 0, w_prime # Integration limits based on the current position along the axis

def integrand(t):
temp_coordinate = coordinate[:]
temp_coordinate[index] = t
return np.sqrt(1 + self.apply_derivative(temp_coordinate, axis) ** 2)

arc_length = quad(integrand, a, b)[0]
return arc_length - coordinate[index]

def transform_coordinates(self, coordinate: List[float]) -> List[float]:
"""
Transform coordinates to preserve the distance between points on a sine wave when perturbation is applied.
Achieved by calculating the integral of the length between [0,0,0] and given coordinate.
Returns:
Callable[[List[float]], List[float]]: The coordinates transformation function.
"""
for i, var in enumerate(self.variables):
index = AXIS_TO_INDEX_MAP[var]
w = coordinate[index]
result = root_scalar(
self.get_arc_length_equation,
args=(coordinate, var),
bracket=[0, EQUATION_RANGE_COEFFICIENT * w],
method="brentq",
)
coordinate[index] = result.root
return coordinate

def apply_perturbation(self, coordinate: List[float]) -> List[float]:
"""
Expand All @@ -151,20 +122,43 @@ def apply_perturbation(self, coordinate: List[float]) -> List[float]:
perturbed_coordinate[2] += perturbation_value
return perturbed_coordinate

def transform_coordinates(self, coordinate: List[float]) -> List[float]:
for i, var in enumerate(self.variables):
def get_json(self) -> dict:
return {
"type": self.__class__.__name__,
"function": str(self.function),
"variables": self.variables,
}

def arc_length_eq(x_prime):
args = coordinate[:]
args[{"x": 0, "y": 1, "z": 2}[var]] = x_prime
return quad(lambda t: np.sqrt(1 + self.apply_derivative(args, var) ** 2), 0, x_prime)[0] - coordinate[i]

result = root_scalar(arc_length_eq, bracket=[0, coordinate[i] * 5], method="brentq")
coordinate[i] = result.root
return coordinate
class SineWavePerturbationFunctionHolder(PerturbationFunctionHolder):
amplitude: float = 0.05
wavelength: float = 1
phase: float = 0
axis: str = "x"

def __init__(
self,
amplitude: float = 0.05,
wavelength: float = 1,
phase: float = 0,
axis: str = "x",
**data: Any,
):
super().__init__(**data)
self.amplitude = amplitude
self.wavelength = wavelength
self.phase = phase
self.axis = axis
function = lambda x: self.amplitude * sp.sin(2 * sp.pi * x / self.wavelength + self.phase)

Check failure on line 152 in src/py/mat3ra/made/tools/utils/functions.py

View workflow job for this annotation

GitHub Actions / run-py-linter (3.8.6)

Ruff (E731)

src/py/mat3ra/made/tools/utils/functions.py:152:9: E731 Do not assign a `lambda` expression, use a `def`
variables = [self.axis]

PerturbationFunctionHolder.__init__(self, function=function, variables=variables)

def get_json(self) -> dict:
return {
"type": self.__class__.__name__,
"variables": self.variables,
"amplitude": self.amplitude,
"wavelength": self.wavelength,
"phase": self.phase,
"axis": self.axis,
}

0 comments on commit 85f65af

Please sign in to comment.