diff --git a/src/py/mat3ra/made/tools/build/perturbation/configuration.py b/src/py/mat3ra/made/tools/build/perturbation/configuration.py index 0692c2ed..31ab2067 100644 --- a/src/py/mat3ra/made/tools/build/perturbation/configuration.py +++ b/src/py/mat3ra/made/tools/build/perturbation/configuration.py @@ -4,14 +4,14 @@ from mat3ra.made.material import Material from pydantic import BaseModel -from ...utils.functions import SineWavePerturbationFunctionHolder, GeneralPerturbationFunctionHolder +from ...utils.functions import SineWavePerturbationFunctionHolder, PerturbationFunctionHolder class PerturbationConfiguration(BaseModel, InMemoryEntity): material: Material - perturbation_function_holder: Union[ - SineWavePerturbationFunctionHolder, GeneralPerturbationFunctionHolder - ] = SineWavePerturbationFunctionHolder() + perturbation_function_holder: Union[SineWavePerturbationFunctionHolder, PerturbationFunctionHolder] = ( + SineWavePerturbationFunctionHolder() + ) use_cartesian_coordinates: bool = True class Config: diff --git a/src/py/mat3ra/made/tools/utils/functions.py b/src/py/mat3ra/made/tools/utils/functions.py index 8e190275..dbdcb033 100644 --- a/src/py/mat3ra/made/tools/utils/functions.py +++ b/src/py/mat3ra/made/tools/utils/functions.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Literal +from typing import Any, Callable, List, Optional import numpy as np import sympy as sp @@ -17,7 +17,7 @@ def apply_function(self, coordinate: List[float]) -> float: """ raise NotImplementedError - def apply_derivative(self, coordinate: List[float]) -> float: + def apply_derivative(self, coordinate: List[float], axis: str) -> float: """ Get the derivative of the function at the given coordinate """ @@ -36,81 +36,11 @@ def get_json(self) -> dict: raise NotImplementedError -class PerturbationFunctionHolder(FunctionHolder): - def get_arc_length_equation(self, w_prime: float, w: float) -> float: - """ - Get the arc length equation for the perturbation function. - """ - arc_length = quad( - lambda t: np.sqrt(1 + (self.apply_derivative([t]) ** 2)), - a=0, - b=w_prime, - )[0] - return arc_length - w - - def transform_coordinates(self, coordinate: List[float]) -> List[float]: - """ - Transform coordinates to preserve the distance between points on a sine wave when perturbation is applied. - Achieved by calculating the integral of the length between [0,0,0] and given coordinate. - - Returns: - Callable[[List[float]], List[float]]: The coordinates transformation function. - """ - raise NotImplementedError - - def apply_perturbation(self, coordinate: List[float]) -> List[float]: - """ - Apply the perturbation to the given coordinate. - """ - raise NotImplementedError - - -class SineWavePerturbationFunctionHolder(PerturbationFunctionHolder): - amplitude: float = 0.05 - wavelength: float = 1 - phase: float = 0 - axis: Literal["x", "y"] = "x" - - def apply_function(self, coordinate: List[float]) -> float: - w = coordinate[AXIS_TO_INDEX_MAP[self.axis]] - return self.amplitude * np.sin(2 * np.pi * w / self.wavelength + self.phase) - - def apply_derivative(self, coordinate: List[float]) -> float: - w = coordinate[AXIS_TO_INDEX_MAP[self.axis]] - return self.amplitude * 2 * np.pi / self.wavelength * np.cos(2 * np.pi * w / self.wavelength + self.phase) - - def apply_perturbation(self, coordinate: List[float]) -> List[float]: - return [coordinate[0], coordinate[1], coordinate[2] + self.apply_function(coordinate)] - - def transform_coordinates(self, coordinate: List[float]) -> List[float]: - index = AXIS_TO_INDEX_MAP[self.axis] - - w = coordinate[index] - # Find x' such that the integral from 0 to x' equals x - result = root_scalar( - self.get_arc_length_equation, - args=w, - bracket=[0, EQUATION_RANGE_COEFFICIENT * w], - method="brentq", - ) - coordinate[index] = result.root - return coordinate - - def get_json(self) -> dict: - return { - "type": self.__class__.__name__, - "amplitude": self.amplitude, - "wavelength": self.wavelength, - "phase": self.phase, - "axis": self.axis, - } - - def default_function(coordinate: List[float]) -> float: return 0 -class GeneralPerturbationFunctionHolder(PerturbationFunctionHolder): +class PerturbationFunctionHolder(FunctionHolder): variables: List[str] = ["x"] symbols: List[sp.Symbol] = [sp.Symbol(var) for var in variables] function: sp.Expr = sp.Symbol("f") @@ -120,11 +50,14 @@ class GeneralPerturbationFunctionHolder(PerturbationFunctionHolder): class Config: arbitrary_types_allowed = True - def __init__(self, function: Callable, variables: List[str], **data: Any): + def __init__(self, function: Optional[Callable] = None, variables: Optional[List[str]] = None, **data: Any): """ Initializes with a function involving multiple variables. """ - + if function is None: + function = default_function + if variables is None: + variables = ["x"] super().__init__(**data) self.variables = variables self.symbols = sp.symbols(variables) @@ -135,12 +68,50 @@ def __init__(self, function: Callable, variables: List[str], **data: Any): } def apply_function(self, coordinate: List[float]) -> float: - values = [coordinate[{"x": 0, "y": 1, "z": 2}[var]] for var in self.variables] + values = [coordinate[AXIS_TO_INDEX_MAP[var]] for var in self.variables] return self.function_numeric(*values) - def apply_derivative(self, coordinate: List[float], axis: str) -> float: # type: ignore - values = [coordinate[{"x": 0, "y": 1, "z": 2}[var]] for var in self.variables] - return self.derivatives_numeric[axis](*values) + def apply_derivative(self, coordinate: List[float], axis: str) -> float: + if axis in self.variables: + values = [coordinate[AXIS_TO_INDEX_MAP[var]] for var in self.variables] + return self.derivatives_numeric[axis](*values) + else: + return 0 + + def get_arc_length_equation(self, w_prime: float, coordinate: List[float], axis: str) -> float: + """ + Calculate arc length considering a change along one specific axis. + """ + index = AXIS_TO_INDEX_MAP[axis] + a, b = 0, w_prime # Integration limits based on the current position along the axis + + def integrand(t): + temp_coordinate = coordinate[:] + temp_coordinate[index] = t + return np.sqrt(1 + self.apply_derivative(temp_coordinate, axis) ** 2) + + arc_length = quad(integrand, a, b)[0] + return arc_length - coordinate[index] + + def transform_coordinates(self, coordinate: List[float]) -> List[float]: + """ + Transform coordinates to preserve the distance between points on a sine wave when perturbation is applied. + Achieved by calculating the integral of the length between [0,0,0] and given coordinate. + + Returns: + Callable[[List[float]], List[float]]: The coordinates transformation function. + """ + for i, var in enumerate(self.variables): + index = AXIS_TO_INDEX_MAP[var] + w = coordinate[index] + result = root_scalar( + self.get_arc_length_equation, + args=(coordinate, var), + bracket=[0, EQUATION_RANGE_COEFFICIENT * w], + method="brentq", + ) + coordinate[index] = result.root + return coordinate def apply_perturbation(self, coordinate: List[float]) -> List[float]: """ @@ -151,20 +122,43 @@ def apply_perturbation(self, coordinate: List[float]) -> List[float]: perturbed_coordinate[2] += perturbation_value return perturbed_coordinate - def transform_coordinates(self, coordinate: List[float]) -> List[float]: - for i, var in enumerate(self.variables): + def get_json(self) -> dict: + return { + "type": self.__class__.__name__, + "function": str(self.function), + "variables": self.variables, + } - def arc_length_eq(x_prime): - args = coordinate[:] - args[{"x": 0, "y": 1, "z": 2}[var]] = x_prime - return quad(lambda t: np.sqrt(1 + self.apply_derivative(args, var) ** 2), 0, x_prime)[0] - coordinate[i] - result = root_scalar(arc_length_eq, bracket=[0, coordinate[i] * 5], method="brentq") - coordinate[i] = result.root - return coordinate +class SineWavePerturbationFunctionHolder(PerturbationFunctionHolder): + amplitude: float = 0.05 + wavelength: float = 1 + phase: float = 0 + axis: str = "x" + + def __init__( + self, + amplitude: float = 0.05, + wavelength: float = 1, + phase: float = 0, + axis: str = "x", + **data: Any, + ): + super().__init__(**data) + self.amplitude = amplitude + self.wavelength = wavelength + self.phase = phase + self.axis = axis + function = lambda x: self.amplitude * sp.sin(2 * sp.pi * x / self.wavelength + self.phase) + variables = [self.axis] + + PerturbationFunctionHolder.__init__(self, function=function, variables=variables) def get_json(self) -> dict: return { "type": self.__class__.__name__, - "variables": self.variables, + "amplitude": self.amplitude, + "wavelength": self.wavelength, + "phase": self.phase, + "axis": self.axis, }