diff --git a/elastica/rod/cosserat_rod.py b/elastica/rod/cosserat_rod.py index a6c626389..a7bba93fe 100644 --- a/elastica/rod/cosserat_rod.py +++ b/elastica/rod/cosserat_rod.py @@ -2,6 +2,7 @@ import numpy as np +from numpy.typing import NDArray import functools import numba from elastica.rod import RodBase @@ -147,39 +148,39 @@ class CosseratRod(RodBase, KnotTheory): def __init__( self, - n_elements, - position, - velocity, - omega, - acceleration, - angular_acceleration, - directors, - radius, - mass_second_moment_of_inertia, - inv_mass_second_moment_of_inertia, - shear_matrix, - bend_matrix, - density, - volume, - mass, - internal_forces, - internal_torques, - external_forces, - external_torques, - lengths, - rest_lengths, - tangents, - dilatation, - dilatation_rate, - voronoi_dilatation, - rest_voronoi_lengths, - sigma, - kappa, - rest_sigma, - rest_kappa, - internal_stress, - internal_couple, - ring_rod_flag, + n_elements: int, + position: NDArray[np.floating], + velocity: NDArray[np.floating], + omega: NDArray[np.floating], + acceleration: NDArray[np.floating], + angular_acceleration: NDArray[np.floating], + directors: NDArray[np.floating], + radius: NDArray[np.floating], + mass_second_moment_of_inertia: NDArray[np.floating], + inv_mass_second_moment_of_inertia: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + density: NDArray[np.floating], + volume: NDArray[np.floating], + mass: NDArray[np.floating], + internal_forces: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_forces: NDArray[np.floating], + external_torques: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + dilatation: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + sigma: NDArray[np.floating], + kappa: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + rest_kappa: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_couple: NDArray[np.floating], + ring_rod_flag: bool, ): self.n_elems = n_elements self.position_collection = position @@ -242,9 +243,9 @@ def __init__( def straight_rod( cls, n_elements: int, - start: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + start: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, @@ -390,9 +391,9 @@ def straight_rod( def ring_rod( cls, n_elements: int, - ring_center_position: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + ring_center_position: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, @@ -533,7 +534,7 @@ def ring_rod( ring_rod_flag, ) - def compute_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques(self, time: float): """ Compute internal forces and torques. We need to compute internal forces and torques before the acceleration because they are used in interaction. Thus in order to speed up simulation, we will compute internal forces and torques @@ -588,7 +589,7 @@ def compute_internal_forces_and_torques(self, time): ) # Interface to time-stepper mixins (Symplectic, Explicit), which calls this method - def update_accelerations(self, time): + def update_accelerations(self, time: float): """ Updates the acceleration variables @@ -610,7 +611,7 @@ def update_accelerations(self, time): self.dilatation, ) - def zeroed_out_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques(self, time: float): _zeroed_out_external_forces_and_torques( self.external_forces, self.external_torques ) diff --git a/elastica/rod/factory_function.py b/elastica/rod/factory_function.py index 6eca6b5ac..378398190 100644 --- a/elastica/rod/factory_function.py +++ b/elastica/rod/factory_function.py @@ -3,17 +3,18 @@ import logging import numpy as np from numpy.testing import assert_allclose +from numpy.typing import NDArray from elastica.utils import MaxDimension, Tolerance from elastica._linalg import _batch_cross, _batch_norm, _batch_dot def allocate( - n_elements, - direction, - normal, - base_length, - base_radius, - density, + n_elements: int, + direction: NDArray[np.floating], + normal: NDArray[np.floating], + base_length: float, + base_radius: float, + density: float, youngs_modulus: float, *, rod_origin_position: np.ndarray, @@ -335,14 +336,14 @@ def allocate( """ -def _assert_dim(vector, max_dim: int, name: str): +def _assert_dim(vector: np.ndarray, max_dim: int, name: str): assert vector.ndim < max_dim, ( f"Input {name} dimension is not correct {vector.shape}" + f" It should be maximum {max_dim}D vector or single floating number." ) -def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): +def _assert_shape(array: np.ndarray, expected_shape: Tuple[int, ...], name: str): assert array.shape == expected_shape, ( f"Given {name} shape is not correct, it should be " + str(expected_shape) @@ -351,7 +352,9 @@ def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): ) -def _position_validity_checker(position, start, n_elements): +def _position_validity_checker( + position: NDArray[np.floating], start: NDArray[np.floating], n_elements: int +): """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements + 1), "position") @@ -367,7 +370,9 @@ def _position_validity_checker(position, start, n_elements): ) -def _directors_validity_checker(directors, tangents, n_elements): +def _directors_validity_checker( + directors: NDArray[np.floating], tangents: NDArray[np.floating], n_elements: int +): """Checker on user-defined directors validity""" _assert_shape( directors, (MaxDimension.value(), MaxDimension.value(), n_elements), "directors" @@ -413,7 +418,11 @@ def _directors_validity_checker(directors, tangents, n_elements): ) -def _position_validity_checker_ring_rod(position, ring_center_position, n_elements): +def _position_validity_checker_ring_rod( + position: NDArray[np.floating], + ring_center_position: NDArray[np.floating], + n_elements: int, +): """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements), "position") diff --git a/elastica/rod/knot_theory.py b/elastica/rod/knot_theory.py index c66ff0e27..5c784a9e7 100644 --- a/elastica/rod/knot_theory.py +++ b/elastica/rod/knot_theory.py @@ -14,6 +14,7 @@ from numba import njit import numpy as np +from numpy.typing import NDArray from elastica.rod.rod_base import RodBase from elastica._linalg import _batch_norm, _batch_dot, _batch_cross @@ -138,7 +139,9 @@ def compute_link( )[0] -def compute_twist(center_line, normal_collection): +def compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +): """ Compute the twist of a rod, using center_line and normal collection. @@ -189,7 +192,9 @@ def compute_twist(center_line, normal_collection): @njit(cache=True) -def _compute_twist(center_line, normal_collection): +def _compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +): """ Parameters ---------- @@ -264,7 +269,11 @@ def _compute_twist(center_line, normal_collection): return total_twist, local_twist -def compute_writhe(center_line, segment_length, type_of_additional_segment): +def compute_writhe( + center_line: NDArray[np.floating], + segment_length: float, + type_of_additional_segment: str, +): """ This function computes the total writhe history of a rod. @@ -314,7 +323,7 @@ def compute_writhe(center_line, segment_length, type_of_additional_segment): @njit(cache=True) -def _compute_writhe(center_line): +def _compute_writhe(center_line: NDArray[np.floating]): """ Parameters ---------- @@ -386,9 +395,9 @@ def _compute_writhe(center_line): def compute_link( - center_line: np.ndarray, - normal_collection: np.ndarray, - radius: np.ndarray, + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], segment_length: float, type_of_additional_segment: str, ): @@ -470,7 +479,11 @@ def compute_link( @njit(cache=True) -def _compute_auxiliary_line(center_line, normal_collection, radius): +def _compute_auxiliary_line( + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], +): """ This function computes the auxiliary line using rod center line and normal collection. @@ -525,7 +538,9 @@ def _compute_auxiliary_line(center_line, normal_collection, radius): @njit(cache=True) -def _compute_link(center_line, auxiliary_line): +def _compute_link( + center_line: NDArray[np.floating], auxiliary_line: NDArray[np.floating] +): """ Parameters @@ -604,7 +619,10 @@ def _compute_link(center_line, auxiliary_line): @njit(cache=True) def _compute_auxiliary_line_added_segments( - beginning_direction, end_direction, auxiliary_line, segment_length + beginning_direction: NDArray[np.floating], + end_direction: NDArray[np.floating], + auxiliary_line: NDArray[np.floating], + segment_length: float, ): """ This code is for computing position of added segments to the auxiliary line. @@ -647,7 +665,9 @@ def _compute_auxiliary_line_added_segments( @njit(cache=True) def _compute_additional_segment( - center_line, segment_length, type_of_additional_segment + center_line: NDArray[np.floating], + segment_length: float, + type_of_additional_segment: str, ): """ This function adds two points at the end of center line. Distance from the center line is given by segment_length. diff --git a/elastica/rod/rod_base.py b/elastica/rod/rod_base.py index e3e22400c..28e20b708 100644 --- a/elastica/rod/rod_base.py +++ b/elastica/rod/rod_base.py @@ -1,5 +1,7 @@ __doc__ = """Base class for rods""" +import numpy as np +from numpy.typing import NDArray class RodBase: """ @@ -15,9 +17,9 @@ def __init__(self) -> None: """ RodBase does not take any arguments. """ - self.position_collection: int - self.omega_collection: int - self.acceleration_collection: int - self.alpha_collection: int - self.external_forces: int - self.external_torques: int + self.position_collection: NDArray[np.floating] + self.omega_collection: NDArray[np.floating] + self.acceleration_collection: NDArray[np.floating] + self.alpha_collection: NDArray[np.floating] + self.external_forces: NDArray[np.floating] + self.external_torques: NDArray[np.floating]