Skip to content

Commit

Permalink
feat: first implementation of general pertrubation
Browse files Browse the repository at this point in the history
  • Loading branch information
VsevolodX committed Aug 8, 2024
1 parent 963822a commit 028fa4b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/py/mat3ra/made/tools/build/perturbation/configuration.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Union

from mat3ra.code.entity import InMemoryEntity
from mat3ra.made.material import Material
from pydantic import BaseModel

from ...utils.functions import SineWavePerturbationFunctionHolder
from ...utils.functions import SineWavePerturbationFunctionHolder, GeneralPerturbationFunctionHolder


class PerturbationConfiguration(BaseModel, InMemoryEntity):
material: Material
perturbation_function_holder: SineWavePerturbationFunctionHolder = SineWavePerturbationFunctionHolder()
perturbation_function_holder: Union[SineWavePerturbationFunctionHolder, GeneralPerturbationFunctionHolder] = (
SineWavePerturbationFunctionHolder()
)
use_cartesian_coordinates: bool = True

class Config:
Expand Down
65 changes: 64 additions & 1 deletion src/py/mat3ra/made/tools/utils/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Literal
from typing import List, Literal, Callable, Any

import numpy as np
from pydantic import BaseModel
Expand Down Expand Up @@ -103,3 +103,66 @@ def get_json(self) -> dict:
"phase": self.phase,
"axis": self.axis,
}


import sympy as sp


class GeneralPerturbationFunctionHolder(PerturbationFunctionHolder):
variables: List[str] = ["x"]
symbols: List[sp.Symbol] = [sp.Symbol(var) for var in variables]
function: sp.Expr = sp.Symbol("f")
function_numeric: Callable = lambda x: x
derivatives_numeric: dict = {}

class Config:
arbitrary_types_allowed = True

def __init__(self, function: Callable, variables: List[str], **data: Any):
"""
Initializes with a function involving multiple variables.
"""

super().__init__(**data)
self.variables = variables
self.symbols = sp.symbols(variables)
self.function = function(*self.symbols)
self.function_numeric = sp.lambdify(self.symbols, self.function, modules=["numpy"])
self.derivatives_numeric = {
var: sp.lambdify(self.symbols, sp.diff(self.function, var), modules=["numpy"]) for var in variables
}

def apply_function(self, coordinate: List[float]) -> float:
values = [coordinate[{"x": 0, "y": 1, "z": 2}[var]] for var in self.variables]
return self.function_numeric(*values)

def apply_derivative(self, coordinate: List[float], axis: str) -> float: # type: ignore
values = [coordinate[{"x": 0, "y": 1, "z": 2}[var]] for var in self.variables]
return self.derivatives_numeric[axis](*values)

def apply_perturbation(self, coordinate: List[float]) -> List[float]:
"""
Apply the perturbation to the given coordinate by adding the function's value to the third coordinate (z-axis).
"""
perturbation_value = self.apply_function(coordinate)
perturbed_coordinate = coordinate[:]
perturbed_coordinate[2] += perturbation_value
return perturbed_coordinate

def transform_coordinates(self, coordinate: List[float]) -> List[float]:
for i, var in enumerate(self.variables):

def arc_length_eq(x_prime):
args = coordinate[:]
args[{"x": 0, "y": 1, "z": 2}[var]] = x_prime
return quad(lambda t: np.sqrt(1 + self.apply_derivative(args, var) ** 2), 0, x_prime)[0] - coordinate[i]

result = root_scalar(arc_length_eq, bracket=[0, coordinate[i] * 5], method="brentq")
coordinate[i] = result.root
return coordinate

def get_json(self) -> dict:
return {
"type": self.__class__.__name__,
"variables": self.variables,
}

0 comments on commit 028fa4b

Please sign in to comment.