From d294e44914e15c21603738adffd73fd82b1eee7b Mon Sep 17 00:00:00 2001 From: Ankith Date: Sat, 11 May 2024 00:00:06 +0530 Subject: [PATCH] Improve typehinting to files at project root --- elastica/_calculus.py | 34 +++- elastica/_contact_functions.py | 23 +-- elastica/_linalg.py | 45 +++-- elastica/_rotations.py | 49 ++--- elastica/_synchronize_periodic_boundary.py | 23 ++- elastica/boundary_conditions.py | 115 ++++++----- elastica/callback_functions.py | 25 +-- elastica/contact_forces.py | 49 ++--- elastica/contact_utils.py | 53 +++-- elastica/dissipation.py | 25 ++- elastica/external_forces.py | 79 +++++--- elastica/interaction.py | 168 +++++++++------- elastica/joint.py | 219 ++++++++++++--------- elastica/restart.py | 20 +- elastica/rod/cosserat_rod.py | 99 +++++----- elastica/rod/data_structures.py | 8 +- elastica/rod/factory_function.py | 36 ++-- elastica/rod/knot_theory.py | 51 +++-- elastica/rod/rod_base.py | 11 +- elastica/transformations.py | 22 ++- elastica/typing.py | 7 +- elastica/utils.py | 30 ++- 22 files changed, 731 insertions(+), 460 deletions(-) diff --git a/elastica/_calculus.py b/elastica/_calculus.py index eca9829b7..4a4121063 100644 --- a/elastica/_calculus.py +++ b/elastica/_calculus.py @@ -1,23 +1,29 @@ __doc__ = """ Quadrature and difference kernels """ +from typing import Any, Union import numpy as np from numpy import zeros, empty +from numpy.typing import NDArray from numba import njit from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import ( _reset_vector_ghost, ) import functools +from elastica.typing import Float + @functools.lru_cache(maxsize=2) -def _get_zero_array(dim, ndim): +def _get_zero_array(dim: int, ndim: int) -> Union[float, NDArray[np.floating], None]: if ndim == 1: return 0.0 if ndim == 2: return np.zeros((dim, 1)) + return None + @njit(cache=True) -def _trapezoidal(array_collection): +def _trapezoidal(array_collection: NDArray[np.floating]) -> NDArray[np.floating]: """ Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way @@ -63,7 +69,9 @@ def _trapezoidal(array_collection): @njit(cache=True) -def _trapezoidal_for_block_structure(array_collection, ghost_idx): +def _trapezoidal_for_block_structure( + array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer] +) -> NDArray[np.floating]: """ Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form specifically for the block structure implementation and there is a reset function call, to reset @@ -115,7 +123,9 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx): @njit(cache=True) -def _two_point_difference(array_collection): +def _two_point_difference( + array_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function does differentiation. @@ -156,7 +166,9 @@ def _two_point_difference(array_collection): @njit(cache=True) -def _two_point_difference_for_block_structure(array_collection, ghost_idx): +def _two_point_difference_for_block_structure( + array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer] +) -> NDArray[np.floating]: """ This function does the differentiation, for Cosserat rod model equations. This form specifically for the block structure implementation and there is a reset function call, to @@ -207,7 +219,7 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx): @njit(cache=True) -def _difference(vector): +def _difference(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ This function computes difference between elements of a batch vector. @@ -238,7 +250,7 @@ def _difference(vector): @njit(cache=True) -def _average(vector): +def _average(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ This function computes the average between elements of a vector. @@ -268,7 +280,9 @@ def _average(vector): @njit(cache=True) -def _clip_array(input_array, vmin, vmax): +def _clip_array( + input_array: NDArray[np.floating], vmin: Float, vmax: Float +) -> NDArray[np.floating]: """ This function clips an array values between user defined minimum and maximum @@ -304,7 +318,7 @@ def _clip_array(input_array, vmin, vmax): @njit(cache=True) -def _isnan_check(array): +def _isnan_check(array: NDArray[Any]) -> bool: """ This function checks if there is any nan inside the array. If there is nan, it returns True boolean. @@ -324,7 +338,7 @@ def _isnan_check(array): Python version: 2.24 µs ± 96.1 ns per loop This version: 479 ns ± 6.49 ns per loop """ - return np.isnan(array).any() + return bool(np.isnan(array).any()) position_difference_kernel = _difference diff --git a/elastica/_contact_functions.py b/elastica/_contact_functions.py index 245d92319..5aec12219 100644 --- a/elastica/_contact_functions.py +++ b/elastica/_contact_functions.py @@ -24,7 +24,9 @@ ) import numba import numpy as np +from numpy.typing import NDArray +from elastica.typing import Float @numba.njit(cache=True) def _calculate_contact_forces_rod_cylinder( @@ -784,17 +786,16 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction( @numba.njit(cache=True) def _calculate_contact_forces_cylinder_plane( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - length, - position_collection, - velocity_collection, - external_forces, -): - + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + surface_tol: Float, + k: Float, + nu: Float, + length: NDArray[np.floating], + position_collection: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + external_forces: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.intp]]: # Compute plane response force # total_forces = system.internal_forces + system.external_forces total_forces = external_forces diff --git a/elastica/_linalg.py b/elastica/_linalg.py index a4995ab25..3123ff75b 100644 --- a/elastica/_linalg.py +++ b/elastica/_linalg.py @@ -1,5 +1,6 @@ __doc__ = """ Convenient linear algebra kernels """ import numpy as np +from numpy.typing import NDArray from numba import njit from numpy import sqrt import functools @@ -8,7 +9,7 @@ @functools.lru_cache(maxsize=1) -def levi_civita_tensor(dim): +def levi_civita_tensor(dim: int) -> NDArray[np.floating]: """ Parameters @@ -28,7 +29,9 @@ def levi_civita_tensor(dim): @njit(cache=True) -def _batch_matvec(matrix_collection, vector_collection): +def _batch_matvec( + matrix_collection: NDArray[np.floating], vector_collection: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does batch matrix and batch vector product @@ -59,7 +62,10 @@ def _batch_matvec(matrix_collection, vector_collection): @njit(cache=True) -def _batch_matmul(first_matrix_collection, second_matrix_collection): +def _batch_matmul( + first_matrix_collection: NDArray[np.floating], + second_matrix_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This is batch matrix matrix multiplication function. Only batch of 3x3 matrices can be multiplied. @@ -93,7 +99,10 @@ def _batch_matmul(first_matrix_collection, second_matrix_collection): @njit(cache=True) -def _batch_cross(first_vector_collection, second_vector_collection): +def _batch_cross( + first_vector_collection: NDArray[np.floating], + second_vector_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function does cross product between two batch vectors. @@ -133,7 +142,9 @@ def _batch_cross(first_vector_collection, second_vector_collection): @njit(cache=True) -def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector): +def _batch_vec_oneD_vec_cross( + first_vector_collection: NDArray[np.floating], second_vector: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does cross product between batch vector and a 1D vector. Idea of having this function is that, for friction calculations, we dont @@ -177,7 +188,9 @@ def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector): @njit(cache=True) -def _batch_dot(first_vector, second_vector): +def _batch_dot( + first_vector: NDArray[np.floating], second_vector: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does batch vec and batch vec dot product. Parameters @@ -204,7 +217,7 @@ def _batch_dot(first_vector, second_vector): @njit(cache=True) -def _batch_norm(vector): +def _batch_norm(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ This function computes norm of a batch vector Parameters @@ -233,7 +246,9 @@ def _batch_norm(vector): @njit(cache=True) -def _batch_product_i_k_to_ik(vector1, vector2): +def _batch_product_i_k_to_ik( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does outer product following 'i,k->ik'. vector1 has shape of 3 and vector 2 has shape of blocksize @@ -262,7 +277,9 @@ def _batch_product_i_k_to_ik(vector1, vector2): @njit(cache=True) -def _batch_product_i_ik_to_k(vector1, vector2): +def _batch_product_i_ik_to_k( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does the following product 'i,ik->k' This function do dot product between a vector of 3 elements @@ -293,7 +310,9 @@ def _batch_product_i_ik_to_k(vector1, vector2): @njit(cache=True) -def _batch_product_k_ik_to_ik(vector1, vector2): +def _batch_product_k_ik_to_ik( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function does the following product 'k, ik->ik' Parameters @@ -322,7 +341,9 @@ def _batch_product_k_ik_to_ik(vector1, vector2): @njit(cache=True) -def _batch_vector_sum(vector1, vector2): +def _batch_vector_sum( + vector1: NDArray[np.floating], vector2: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function is for summing up two vectors. Although this function is not faster than pure python implementation @@ -352,7 +373,7 @@ def _batch_vector_sum(vector1, vector2): @njit(cache=True) -def _batch_matrix_transpose(input_matrix): +def _batch_matrix_transpose(input_matrix: NDArray[np.floating]) -> NDArray[np.floating]: """ This function takes an batch input matrix and transpose it. Parameters diff --git a/elastica/_rotations.py b/elastica/_rotations.py index 25ec14212..6ce664d80 100644 --- a/elastica/_rotations.py +++ b/elastica/_rotations.py @@ -8,14 +8,18 @@ from numpy import cos from numpy import sqrt from numpy import arccos +from numpy.typing import NDArray from numba import njit from elastica._linalg import _batch_matmul +from elastica.typing import Float @njit(cache=True) -def _get_rotation_matrix(scale: float, axis_collection): +def _get_rotation_matrix( + scale: Float, axis_collection: NDArray[np.floating] +) -> NDArray[np.floating]: blocksize = axis_collection.shape[1] rot_mat = np.empty((3, 3, blocksize)) @@ -49,7 +53,11 @@ def _get_rotation_matrix(scale: float, axis_collection): @njit(cache=True) -def _rotate(director_collection, scale: float, axis_collection): +def _rotate( + director_collection: NDArray[np.floating], + scale: Float, + axis_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Does alibi rotations https://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities @@ -74,7 +82,7 @@ def _rotate(director_collection, scale: float, axis_collection): @njit(cache=True) -def _inv_rotate(director_collection): +def _inv_rotate(director_collection: NDArray[np.floating]) -> NDArray[np.floating]: """ Calculated rate of change using Rodrigues' formula @@ -156,12 +164,15 @@ def _inv_rotate(director_collection): return vector_collection +_generate_skew_map_sentinel = (0, 0, 0) + + # TODO: Below contains numpy-only implementations @functools.lru_cache(maxsize=1) -def _generate_skew_map(dim: int): +def _generate_skew_map(dim: int) -> list[tuple[int, int, int]]: # TODO Documentation # Preallocate - mapping_list = [None] * ((dim**2 - dim) // 2) + mapping_list = [_generate_skew_map_sentinel] * ((dim**2 - dim) // 2) # Indexing (i,j), j is the fastest changing # r = 2, r here is rank, we deal with only matrices for index, (i, j) in enumerate(combinations(range(dim), r=2)): @@ -185,7 +196,7 @@ def _generate_skew_map(dim: int): @functools.lru_cache(maxsize=1) -def _get_skew_map(dim): +def _get_skew_map(dim: int) -> tuple[tuple[int, int, int], ...]: """Generates mapping from src to target skew-symmetric operator For input vector V and output Matrix M (represented in lexicographical index), @@ -208,7 +219,7 @@ def _get_skew_map(dim): @functools.lru_cache(maxsize=1) -def _get_inv_skew_map(dim): +def _get_inv_skew_map(dim: int) -> tuple[tuple[int, int, int], ...]: # TODO Documentation # (vec_src, mat_i, mat_j, sign) mapping_list = _generate_skew_map(dim) @@ -219,7 +230,7 @@ def _get_inv_skew_map(dim): @functools.lru_cache(maxsize=1) -def _get_diag_map(dim): +def _get_diag_map(dim: int) -> tuple[int, ...]: """Generates lexicographic mapping to diagonal in a serialized matrix-type For input dimension dim we calculate mapping to * in Matrix M below @@ -231,17 +242,10 @@ def _get_diag_map(dim): in a dimension agnostic way. """ - # Preallocate - mapping_list = [None] * dim - - # Store linear indices - for dim_iter in range(dim): - mapping_list[dim_iter] = dim_iter * (dim + 1) - - return tuple(mapping_list) + return tuple([dim_iter * (dim + 1) for dim_iter in range(dim)]) -def _skew_symmetrize(vector): +def _skew_symmetrize(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ Parameters @@ -276,7 +280,7 @@ def _skew_symmetrize(vector): # This is purely for testing and optimization sake # While calculating u^2, use u with einsum instead, as it is tad bit faster -def _skew_symmetrize_sq(vector): +def _skew_symmetrize_sq(vector: NDArray[np.floating]) -> NDArray[np.floating]: """ Generate the square of an orthogonal matrix from vector elements @@ -298,12 +302,11 @@ def _skew_symmetrize_sq(vector): hardcoded : 23.1 µs ± 481 ns per loop this version: 14.1 µs ± 96.9 ns per loop """ - dim, _ = vector.shape # First generate array of [x^2, xy, xz, yx, y^2, yz, zx, zy, z^2] # across blocksize # This is slightly faster than doing v[np.newaxis,:,:] * v[:,np.newaxis,:] - products_xy = np.einsum("ik,jk->ijk", vector, vector) + products_xy: NDArray[np.floating] = np.einsum("ik,jk->ijk", vector, vector) # No copy made here, as we do not change memory layout # products_xy = products_xy.reshape((dim * dim, -1)) @@ -335,7 +338,9 @@ def _skew_symmetrize_sq(vector): return products_xy -def _get_skew_symmetric_pair(vector_collection): +def _get_skew_symmetric_pair( + vector_collection: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ Parameters @@ -351,7 +356,7 @@ def _get_skew_symmetric_pair(vector_collection): return u, u_sq -def _inv_skew_symmetrize(matrix): +def _inv_skew_symmetrize(matrix: NDArray[np.floating]) -> NDArray[np.floating]: """ Return the vector elements from a skew-symmetric matrix M diff --git a/elastica/_synchronize_periodic_boundary.py b/elastica/_synchronize_periodic_boundary.py index b4fe87b4f..73787a8bb 100644 --- a/elastica/_synchronize_periodic_boundary.py +++ b/elastica/_synchronize_periodic_boundary.py @@ -2,12 +2,18 @@ """These functions are used to synchronize periodic boundaries for ring rods. """ ) +from typing import Any from numba import njit +import numpy as np +from numpy.typing import NDArray from elastica.boundary_conditions import ConstraintBase +from elastica.typing import Float, SystemType @njit(cache=True) -def _synchronize_periodic_boundary_of_vector_collection(input, periodic_idx): +def _synchronize_periodic_boundary_of_vector_collection( + input: NDArray[np.floating], periodic_idx: NDArray[np.floating] +) -> None: """ This function synchronizes the periodic boundaries of a vector collection. Parameters @@ -28,7 +34,9 @@ def _synchronize_periodic_boundary_of_vector_collection(input, periodic_idx): @njit(cache=True) -def _synchronize_periodic_boundary_of_matrix_collection(input, periodic_idx): +def _synchronize_periodic_boundary_of_matrix_collection( + input: NDArray[np.floating], periodic_idx: NDArray[np.floating] +) -> None: """ This function synchronizes the periodic boundaries of a matrix collection. Parameters @@ -50,7 +58,9 @@ def _synchronize_periodic_boundary_of_matrix_collection(input, periodic_idx): @njit(cache=True) -def _synchronize_periodic_boundary_of_scalar_collection(input, periodic_idx): +def _synchronize_periodic_boundary_of_scalar_collection( + input: NDArray[np.floating], periodic_idx: NDArray[np.floating] +) -> None: """ This function synchronizes the periodic boundaries of a scalar collection. @@ -76,10 +86,11 @@ class _ConstrainPeriodicBoundaries(ConstraintBase): is to synchronize periodic boundaries of ring rod. """ - def __init__(self, **kwargs): + # TODO: improve typing + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def constrain_values(self, rod, time): + def constrain_values(self, rod: SystemType, time: Float) -> None: _synchronize_periodic_boundary_of_vector_collection( rod.position_collection, rod.periodic_boundary_nodes_idx ) @@ -87,7 +98,7 @@ def constrain_values(self, rod, time): rod.director_collection, rod.periodic_boundary_elems_idx ) - def constrain_rates(self, rod, time): + def constrain_rates(self, rod: SystemType, time: Float) -> None: _synchronize_periodic_boundary_of_vector_collection( rod.velocity_collection, rod.periodic_boundary_nodes_idx ) diff --git a/elastica/boundary_conditions.py b/elastica/boundary_conditions.py index 0cf666c8c..6dc2f7ab3 100644 --- a/elastica/boundary_conditions.py +++ b/elastica/boundary_conditions.py @@ -1,9 +1,10 @@ __doc__ = """ Built-in boundary condition implementationss """ import warnings -from typing import Optional +from typing import Any, Optional, Tuple import numpy as np +from numpy.typing import NDArray from abc import ABC, abstractmethod @@ -11,7 +12,7 @@ from elastica._linalg import _batch_matvec, _batch_matrix_transpose from elastica._rotations import _get_rotation_matrix -from elastica.typing import SystemType, RodType +from elastica.typing import Float, SystemType, RodType class ConstraintBase(ABC): @@ -34,7 +35,7 @@ class ConstraintBase(ABC): _constrained_position_idx: np.ndarray _constrained_director_idx: np.ndarray - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize boundary condition""" try: self._system = kwargs["_system"] @@ -67,7 +68,7 @@ def constrained_director_idx(self) -> Optional[np.ndarray]: return self._constrained_director_idx @abstractmethod - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: Float) -> None: # TODO: In the future, we can remove rod and use self.system """ Constrain values (position and/or directors) of a rod object. @@ -82,7 +83,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: pass @abstractmethod - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: Float) -> None: # TODO: In the future, we can remove rod and use self.system """ Constrain rates (velocity and/or omega) of a rod object. @@ -103,14 +104,14 @@ class FreeBC(ConstraintBase): Boundary condition template. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: Float) -> None: """In FreeBC, this routine simply passes.""" pass - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: Float) -> None: """In FreeBC, this routine simply passes.""" pass @@ -143,7 +144,12 @@ class OneEndFixedBC(ConstraintBase): ... ) """ - def __init__(self, fixed_position, fixed_directors, **kwargs): + def __init__( + self, + fixed_position: Tuple[int, ...], + fixed_directors: Tuple[int, ...], + **kwargs: Any, + ) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -159,7 +165,7 @@ def __init__(self, fixed_position, fixed_directors, **kwargs): self.fixed_position_collection = np.array(fixed_position) self.fixed_directors_collection = np.array(fixed_directors) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: Float) -> None: # system.position_collection[..., 0] = self.fixed_position # system.director_collection[..., 0] = self.fixed_directors self.compute_constrain_values( @@ -169,7 +175,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.fixed_directors_collection, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: Float) -> None: # system.velocity_collection[..., 0] = 0.0 # system.omega_collection[..., 0] = 0.0 self.compute_constrain_rates( @@ -180,11 +186,11 @@ def constrain_rates(self, system: SystemType, time: float) -> None: @staticmethod @njit(cache=True) def compute_constrain_values( - position_collection, - fixed_position_collection, - director_collection, - fixed_directors_collection, - ): + position_collection: NDArray[np.floating], + fixed_position_collection: NDArray[np.floating], + director_collection: NDArray[np.floating], + fixed_directors_collection: NDArray[np.floating], + ) -> None: """ Computes constrain values in numba njit decorator @@ -208,7 +214,10 @@ def compute_constrain_values( @staticmethod @njit(cache=True) - def compute_constrain_rates(velocity_collection, omega_collection): + def compute_constrain_rates( + velocity_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + ) -> None: """ Compute contrain rates in numba njit decorator @@ -266,11 +275,11 @@ class GeneralConstraint(ConstraintBase): def __init__( self, - *fixed_data, - translational_constraint_selector: Optional[np.ndarray] = None, - rotational_constraint_selector: Optional[np.array] = None, - **kwargs, - ): + *fixed_data: Any, + translational_constraint_selector: Optional[NDArray[np.bool_]] = None, + rotational_constraint_selector: Optional[NDArray[np.bool_]] = None, + **kwargs: Any, + ) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -331,7 +340,7 @@ def __init__( ) self.rotational_constraint_selector = rotational_constraint_selector.astype(int) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: Float) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_values( system.position_collection, @@ -340,7 +349,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.translational_constraint_selector, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: Float) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_rates( system.velocity_collection, @@ -358,7 +367,10 @@ def constrain_rates(self, system: SystemType, time: float) -> None: @staticmethod @njit(cache=True) def nb_constrain_translational_values( - position_collection, fixed_position_collection, indices, constraint_selector + position_collection: NDArray[np.floating], + fixed_position_collection: NDArray[np.floating], + indices: NDArray[np.integer], + constraint_selector: NDArray[np.integer], ) -> None: """ Computes constrain values in numba njit decorator @@ -393,7 +405,9 @@ def nb_constrain_translational_values( @staticmethod @njit(cache=True) def nb_constrain_translational_rates( - velocity_collection, indices, constraint_selector + velocity_collection: NDArray[np.floating], + indices: NDArray[np.integer], + constraint_selector: NDArray[np.integer], ) -> None: """ Compute constrain rates in numba njit decorator @@ -422,7 +436,10 @@ def nb_constrain_translational_rates( @staticmethod @njit(cache=True) def nb_constrain_rotational_rates( - director_collection, omega_collection, indices, constraint_selector + director_collection: NDArray[np.floating], + omega_collection: NDArray[np.floating], + indices: NDArray[np.integer], + constraint_selector: NDArray[np.integer], ) -> None: """ Compute constrain rates in numba njit decorator @@ -489,7 +506,7 @@ class FixedConstraint(GeneralConstraint): GeneralConstraint: Generalized constraint with configurable DOF. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Initialization of the constraint. Any parameter passed to 'using' will be available in kwargs. @@ -508,7 +525,7 @@ def __init__(self, *args, **kwargs): **kwargs, ) - def constrain_values(self, system: SystemType, time: float) -> None: + def constrain_values(self, system: SystemType, time: Float) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_values( system.position_collection, @@ -522,7 +539,7 @@ def constrain_values(self, system: SystemType, time: float) -> None: self.constrained_director_idx, ) - def constrain_rates(self, system: SystemType, time: float) -> None: + def constrain_rates(self, system: SystemType, time: Float) -> None: if self.constrained_position_idx.size: self.nb_constrain_translational_rates( system.velocity_collection, @@ -537,7 +554,9 @@ def constrain_rates(self, system: SystemType, time: float) -> None: @staticmethod @njit(cache=True) def nb_constraint_rotational_values( - director_collection, fixed_director_collection, indices + director_collection: NDArray[np.floating], + fixed_director_collection: NDArray[np.floating], + indices: NDArray[np.integer], ) -> None: """ Computes constrain values in numba njit decorator @@ -558,7 +577,9 @@ def nb_constraint_rotational_values( @staticmethod @njit(cache=True) def nb_constrain_translational_values( - position_collection, fixed_position_collection, indices + position_collection: NDArray[np.floating], + fixed_position_collection: NDArray[np.floating], + indices: NDArray[np.integer], ) -> None: """ Computes constrain values in numba njit decorator @@ -578,7 +599,9 @@ def nb_constrain_translational_values( @staticmethod @njit(cache=True) - def nb_constrain_translational_rates(velocity_collection, indices) -> None: + def nb_constrain_translational_rates( + velocity_collection: NDArray[np.floating], indices: NDArray[np.integer] + ) -> None: """ Compute constrain rates in numba njit decorator Parameters @@ -598,7 +621,9 @@ def nb_constrain_translational_rates(velocity_collection, indices) -> None: @staticmethod @njit(cache=True) - def nb_constrain_rotational_rates(omega_collection, indices) -> None: + def nb_constrain_rotational_rates( + omega_collection: NDArray[np.floating], indices: NDArray[np.integer] + ) -> None: """ Compute constrain rates in numba njit decorator Parameters @@ -654,15 +679,15 @@ class HelicalBucklingBC(ConstraintBase): def __init__( self, - position_start: np.ndarray, - position_end: np.ndarray, - director_start: np.ndarray, - director_end: np.ndarray, - twisting_time: float, - slack: float, - number_of_rotations: float, - **kwargs, - ): + position_start: NDArray[np.floating], + position_end: NDArray[np.floating], + director_start: NDArray[np.floating], + director_end: NDArray[np.floating], + twisting_time: Float, + slack: Float, + number_of_rotations: Float, + **kwargs: Any, + ) -> None: """ Helical Buckling initializer @@ -718,7 +743,7 @@ def __init__( @ director_end ) # rotation_matrix wants vectors 3,1 - def constrain_values(self, rod: RodType, time: float) -> None: + def constrain_values(self, rod: RodType, time: Float) -> None: if time > self.twisting_time: rod.position_collection[..., 0] = self.final_start_position rod.position_collection[..., -1] = self.final_end_position @@ -726,7 +751,7 @@ def constrain_values(self, rod: RodType, time: float) -> None: rod.director_collection[..., 0] = self.final_start_directors rod.director_collection[..., -1] = self.final_end_directors - def constrain_rates(self, rod: RodType, time: float) -> None: + def constrain_rates(self, rod: RodType, time: Float) -> None: if time > self.twisting_time: rod.velocity_collection[..., 0] = 0.0 rod.omega_collection[..., 0] = 0.0 diff --git a/elastica/callback_functions.py b/elastica/callback_functions.py index ab865dfd0..5d487805d 100644 --- a/elastica/callback_functions.py +++ b/elastica/callback_functions.py @@ -4,9 +4,12 @@ import sys import numpy as np import logging +from typing import Any, Optional from collections import defaultdict +from elastica.typing import Float, RodType, SystemType + class CallBackBaseClass: """ @@ -19,13 +22,13 @@ class CallBackBaseClass: """ - def __init__(self): + def __init__(self) -> None: """ CallBackBaseClass does not need any input parameters. """ pass - def make_callback(self, system, time, current_step: int): + def make_callback(self, syste: RodType, time: Float, current_step: int) -> None: """ This method is called every time step. Users can define which parameters are called back and recorded. Also users @@ -59,7 +62,7 @@ class MyCallBack(CallBackBaseClass): Collected callback data is saved in this dictionary. """ - def __init__(self, step_skip: int, callback_params): + def __init__(self, step_skip: int, callback_params: dict) -> None: """ Parameters @@ -73,7 +76,7 @@ def __init__(self, step_skip: int, callback_params): self.sample_every = step_skip self.callback_params = callback_params - def make_callback(self, system, time, current_step: int): + def make_callback(self, system: SystemType, time: Float, current_step: int) -> None: if current_step % self.sample_every == 0: @@ -116,8 +119,8 @@ def __init__( directory: str, method: str, initial_file_count: int = 0, - file_save_interval: int = 1e8, - ): + file_save_interval: int = 100_000_000, + ) -> None: """ Parameters ---------- @@ -189,7 +192,7 @@ def __init__( self._pickle = pickle self._ext = "pkl" - def make_callback(self, system, time, current_step: int): + def make_callback(self, system: SystemType, time: Float, current_step: int) -> None: """ Parameters @@ -224,7 +227,7 @@ def make_callback(self, system, time, current_step: int): ): self._dump() - def _dump(self, **kwargs): + def _dump(self, **kwargs: Any) -> None: """ Dump dictionary buffer (self.buffer) to a file and clear the buffer. @@ -247,7 +250,7 @@ def _dump(self, **kwargs): self.buffer_size = 0 self.buffer.clear() - def get_last_saved_path(self) -> str: + def get_last_saved_path(self) -> Optional[str]: """ Return last saved file path. If no file has been saved, return None @@ -257,14 +260,14 @@ def get_last_saved_path(self) -> str: else: return self.save_path.format(self.file_count - 1, self._ext) - def close(self): + def close(self) -> None: """ Save residual buffer """ if self.buffer_size: self._dump() - def clear(self): + def clear(self) -> None: """ Alias to `close` """ diff --git a/elastica/contact_forces.py b/elastica/contact_forces.py index 8f9b0ab5e..9a2548f31 100644 --- a/elastica/contact_forces.py +++ b/elastica/contact_forces.py @@ -1,6 +1,6 @@ __doc__ = """ Numba implementation module containing contact between rods and rigid bodies and other rods rigid bodies or surfaces.""" -from elastica.typing import RodType, SystemType, AllowedContactType +from elastica.typing import Float, RodType, SystemType, AllowedContactType from elastica.rod import RodBase from elastica.rigidbody import Cylinder, Sphere from elastica.surface import Plane @@ -19,6 +19,7 @@ _calculate_contact_forces_cylinder_plane, ) import numpy as np +from numpy.typing import NDArray class NoContact: @@ -32,7 +33,7 @@ class NoContact: """ - def __init__(self): + def __init__(self) -> None: """ NoContact class does not need any input parameters. """ @@ -101,7 +102,7 @@ class RodRodContact(NoContact): """ - def __init__(self, k: float, nu: float): + def __init__(self, k: Float, nu: Float) -> None: """ Parameters ---------- @@ -225,11 +226,11 @@ class RodCylinderContact(NoContact): def __init__( self, - k: float, - nu: float, + k: Float, + nu: Float, velocity_damping_coefficient=0.0, friction_coefficient=0.0, - ): + ) -> None: """ Parameters @@ -338,7 +339,7 @@ class RodSelfContact(NoContact): """ - def __init__(self, k: float, nu: float): + def __init__(self, k: Float, nu: Float) -> None: """ Parameters @@ -435,11 +436,11 @@ class RodSphereContact(NoContact): def __init__( self, - k: float, - nu: float, - velocity_damping_coefficient=0.0, - friction_coefficient=0.0, - ): + k: Float, + nu: Float, + velocity_damping_coefficient: Float = 0.0, + friction_coefficient: Float = 0.0, + ) -> None: """ Parameters ---------- @@ -560,9 +561,9 @@ class RodPlaneContact(NoContact): def __init__( self, - k: float, - nu: float, - ): + k: Float, + nu: Float, + ) -> None: """ Parameters ---------- @@ -652,12 +653,12 @@ class RodPlaneContactWithAnisotropicFriction(NoContact): def __init__( self, - k: float, - nu: float, - slip_velocity_tol: float, - static_mu_array: np.ndarray, - kinetic_mu_array: np.ndarray, - ): + k: Float, + nu: Float, + slip_velocity_tol: Float, + static_mu_array: NDArray[np.floating], + kinetic_mu_array: NDArray[np.floating], + ) -> None: """ Parameters ---------- @@ -776,9 +777,9 @@ class CylinderPlaneContact(NoContact): def __init__( self, - k: float, - nu: float, - ): + k: Float, + nu: Float, + ) -> None: """ Parameters ---------- diff --git a/elastica/contact_utils.py b/elastica/contact_utils.py index 71584d2bb..d91921681 100644 --- a/elastica/contact_utils.py +++ b/elastica/contact_utils.py @@ -3,37 +3,49 @@ from math import sqrt import numba import numpy as np +from numpy.typing import NDArray from elastica._linalg import ( _batch_norm, ) +from typing import Literal, Sequence, TypeVar + +from elastica.typing import Float @numba.njit(cache=True) -def _dot_product(a, b): - sum = 0.0 +def _dot_product(a: Sequence[Float], b: Sequence[Float]) -> Float: + sum: Float = 0.0 for i in range(3): sum += a[i] * b[i] return sum @numba.njit(cache=True) -def _norm(a): +def _norm(a: Sequence[Float]) -> float: return sqrt(_dot_product(a, a)) +_SupportsCompareT = TypeVar("_SupportsCompareT") + + @numba.njit(cache=True) -def _clip(x, low, high): +def _clip(x: Float, low: Float, high: Float) -> Float: return max(low, min(x, high)) # Can this be made more efficient than 2 comp, 1 or? @numba.njit(cache=True) -def _out_of_bounds(x, low, high): +def _out_of_bounds(x: Float, low: Float, high: Float) -> Float: return (x < low) or (x > high) @numba.njit(cache=True) -def _find_min_dist(x1, e1, x2, e2): +def _find_min_dist( + x1: NDArray[np.floating], + e1: NDArray[np.floating], + x2: NDArray[np.floating], + e2: NDArray[np.floating], +) -> tuple[NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]: e1e1 = _dot_product(e1, e1) e1e2 = _dot_product(e1, e2) e2e2 = _dot_product(e2, e2) @@ -99,7 +111,9 @@ def _find_min_dist(x1, e1, x2, e2): @numba.njit(cache=True) -def _aabbs_not_intersecting(aabb_one, aabb_two): +def _aabbs_not_intersecting( + aabb_one: NDArray[np.floating], aabb_two: NDArray[np.floating] +) -> Literal[1, 0]: """Returns true if not intersecting else false""" if (aabb_one[0, 1] < aabb_two[0, 0]) | (aabb_one[0, 0] > aabb_two[0, 1]): return 1 @@ -120,7 +134,7 @@ def _prune_using_aabbs_rod_cylinder( cylinder_director, cylinder_radius, cylinder_length, -): +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod = np.empty((3, 2)) aabb_cylinder = np.empty((3, 2)) @@ -161,7 +175,7 @@ def _prune_using_aabbs_rod_rod( rod_two_position_collection, rod_two_radius_collection, rod_two_length_collection, -): +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod_one = np.empty((3, 2)) aabb_rod_two = np.empty((3, 2)) @@ -199,7 +213,7 @@ def _prune_using_aabbs_rod_sphere( sphere_position, sphere_director, sphere_radius, -): +) -> Literal[1, 0]: max_possible_dimension = np.zeros((3,)) aabb_rod = np.empty((3, 2)) aabb_sphere = np.empty((3, 2)) @@ -231,7 +245,9 @@ def _prune_using_aabbs_rod_sphere( @numba.njit(cache=True) -def _find_slipping_elements(velocity_slip, velocity_threshold): +def _find_slipping_elements( + velocity_slip: NDArray[np.floating], velocity_threshold: Float +) -> NDArray[np.floating]: """ This function takes the velocity of elements and checks if they are larger than the threshold velocity. If the velocity of elements is larger than threshold velocity, that means those elements are slipping. @@ -272,7 +288,7 @@ def _find_slipping_elements(velocity_slip, velocity_threshold): @numba.njit(cache=True) -def _node_to_element_mass_or_force(input): +def _node_to_element_mass_or_force(input: NDArray[np.floating]) -> NDArray[np.floating]: """ This function converts the mass/forces on rod nodes to elements, where special treatment is necessary at the ends. @@ -310,7 +326,10 @@ def _node_to_element_mass_or_force(input): @numba.njit(cache=True) -def _elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): +def _elements_to_nodes_inplace( + vector_in_element_frame: NDArray[np.floating], + vector_in_node_frame: NDArray[np.floating], +) -> None: """ Updating nodal forces using the forces computed on elements Parameters @@ -333,7 +352,9 @@ def _elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): @numba.njit(cache=True) -def _node_to_element_position(node_position_collection): +def _node_to_element_position( + node_position_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ This function computes the position of the elements from the nodal values. @@ -379,7 +400,9 @@ def _node_to_element_position(node_position_collection): @numba.njit(cache=True) -def _node_to_element_velocity(mass, node_velocity_collection): +def _node_to_element_velocity( + mass: NDArray[np.floating], node_velocity_collection: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function computes the velocity of the elements from the nodal values. Uses the velocity of center of mass diff --git a/elastica/dissipation.py b/elastica/dissipation.py index f3629fe76..106b72322 100644 --- a/elastica/dissipation.py +++ b/elastica/dissipation.py @@ -5,12 +5,14 @@ """ from abc import ABC, abstractmethod +from typing import Any -from elastica.typing import RodType, SystemType +from elastica.typing import Float, RodType, SystemType from numba import njit import numpy as np +from numpy.typing import NDArray class DamperBase(ABC): @@ -29,7 +31,8 @@ class DamperBase(ABC): _system: SystemType - def __init__(self, *args, **kwargs): + # TODO typing can be made better + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize damping module""" try: self._system = kwargs["_system"] @@ -40,7 +43,7 @@ def __init__(self, *args, **kwargs): ) @property - def system(self): # -> SystemType: (Return type is not parsed with sphinx book.) + def system(self) -> SystemType: """ get system (rod or rigid body) reference @@ -52,7 +55,7 @@ def system(self): # -> SystemType: (Return type is not parsed with sphinx book. return self._system @abstractmethod - def dampen_rates(self, system: SystemType, time: float): + def dampen_rates(self, system: SystemType, time: Float) -> None: # TODO: In the future, we can remove rod and use self.system """ Dampen rates (velocity and/or omega) of a rod object. @@ -113,7 +116,9 @@ class AnalyticalLinearDamper(DamperBase): Damping coefficient acting on rotational velocity. """ - def __init__(self, damping_constant, time_step, **kwargs): + def __init__( + self, damping_constant: Float, time_step: Float, **kwargs: Any + ) -> None: """ Analytical linear damper initializer @@ -143,7 +148,7 @@ def __init__(self, damping_constant, time_step, **kwargs): * np.diagonal(self._system.inv_mass_second_moment_of_inertia).T ) - def dampen_rates(self, rod: RodType, time: float): + def dampen_rates(self, rod: RodType, time: Float) -> None: rod.velocity_collection[:] = ( rod.velocity_collection * self.translational_damping_coefficient ) @@ -202,7 +207,7 @@ class LaplaceDissipationFilter(DamperBase): Filter term that modifies rod rotational velocity. """ - def __init__(self, filter_order: int, **kwargs): + def __init__(self, filter_order: int, **kwargs: Any) -> None: """ Filter damper initializer @@ -232,7 +237,7 @@ def __init__(self, filter_order: int, **kwargs): self.omega_filter_term = np.zeros_like(self._system.omega_collection) self.filter_function = _filter_function_periodic_condition - def dampen_rates(self, rod: RodType, time: float) -> None: + def dampen_rates(self, rod: RodType, time: Float) -> None: self.filter_function( rod.velocity_collection, @@ -303,7 +308,9 @@ def _filter_function_periodic_condition( @njit(cache=True) def nb_filter_rate( - rate_collection: np.ndarray, filter_term: np.ndarray, filter_order: int + rate_collection: NDArray[np.floating], + filter_term: NDArray[np.floating], + filter_order: int, ) -> None: """ Filters the rod rates (velocities) in numba njit decorator diff --git a/elastica/external_forces.py b/elastica/external_forces.py index cb9c61e6c..63e7b6124 100644 --- a/elastica/external_forces.py +++ b/elastica/external_forces.py @@ -3,8 +3,10 @@ import numpy as np +from numpy.typing import NDArray + from elastica._linalg import _batch_matvec -from elastica.typing import SystemType, RodType +from elastica.typing import Float, SystemType, RodType from elastica.utils import _bspline from numba import njit @@ -22,13 +24,13 @@ class NoForces: """ - def __init__(self): + def __init__(self) -> None: """ NoForces class does not need any input parameters. """ pass - def apply_forces(self, system: SystemType, time: np.float64 = 0.0): + def apply_forces(self, system: SystemType, time: Float = 0.0) -> None: """Apply forces to a rod-like object. In NoForces class, this routine simply passes. @@ -43,7 +45,7 @@ def apply_forces(self, system: SystemType, time: np.float64 = 0.0): """ pass - def apply_torques(self, system: SystemType, time: np.float64 = 0.0): + def apply_torques(self, system: SystemType, time: Float = 0.0): """Apply torques to a rod-like object. In NoForces class, this routine simply passes. @@ -70,7 +72,9 @@ class GravityForces(NoForces): """ - def __init__(self, acc_gravity=np.array([0.0, -9.80665, 0.0])): + def __init__( + self, acc_gravity: NDArray[np.floating] = np.array([0.0, -9.80665, 0.0]) + ) -> None: """ Parameters @@ -82,14 +86,18 @@ def __init__(self, acc_gravity=np.array([0.0, -9.80665, 0.0])): super(GravityForces, self).__init__() self.acc_gravity = acc_gravity - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces(self, system: SystemType, time: Float = 0.0) -> None: self.compute_gravity_forces( self.acc_gravity, system.mass, system.external_forces ) @staticmethod @njit(cache=True) - def compute_gravity_forces(acc_gravity, mass, external_forces): + def compute_gravity_forces( + acc_gravity: NDArray[np.floating], + mass: NDArray[np.floating], + external_forces: NDArray[np.floating], + ) -> None: """ This function add gravitational forces on the nodes. We are using njit decorated function to increase the speed. @@ -122,7 +130,12 @@ class EndpointForces(NoForces): """ - def __init__(self, start_force, end_force, ramp_up_time): + def __init__( + self, + start_force: NDArray[np.floating], + end_force: NDArray[np.floating], + ramp_up_time: Float, + ) -> None: """ Parameters @@ -143,7 +156,7 @@ def __init__(self, start_force, end_force, ramp_up_time): assert ramp_up_time > 0.0 self.ramp_up_time = ramp_up_time - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces(self, system: SystemType, time: Float = 0.0) -> None: self.compute_end_point_forces( system.external_forces, self.start_force, @@ -155,8 +168,12 @@ def apply_forces(self, system: SystemType, time=0.0): @staticmethod @njit(cache=True) def compute_end_point_forces( - external_forces, start_force, end_force, time, ramp_up_time - ): + external_forces: NDArray[np.floating], + start_force: NDArray[np.floating], + end_force: NDArray[np.floating], + time: Float, + ramp_up_time: Float, + ) -> None: """ Compute end point forces that are applied on the rod using numba njit decorator. @@ -190,7 +207,9 @@ class UniformTorques(NoForces): """ - def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): + def __init__( + self, torque: Float, direction: NDArray[np.floating] = np.array([0.0, 0.0, 0.0]) + ) -> None: """ Parameters @@ -204,7 +223,7 @@ def __init__(self, torque, direction=np.array([0.0, 0.0, 0.0])): super(UniformTorques, self).__init__() self.torque = torque * direction - def apply_torques(self, system: SystemType, time: np.float64 = 0.0): + def apply_torques(self, system: SystemType, time: Float = 0.0) -> None: n_elems = system.n_elems torque_on_one_element = ( _batch_product_i_k_to_ik(self.torque, np.ones((n_elems))) / n_elems @@ -224,7 +243,9 @@ class UniformForces(NoForces): 2D (dim, 1) array containing data with 'float' type. Total force applied to a rod-like object. """ - def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): + def __init__( + self, force: Float, direction: NDArray[np.floating] = np.array([0.0, 0.0, 0.0]) + ) -> None: """ Parameters @@ -238,7 +259,7 @@ def __init__(self, force, direction=np.array([0.0, 0.0, 0.0])): super(UniformForces, self).__init__() self.force = (force * direction).reshape(3, 1) - def apply_forces(self, rod: RodType, time: np.float64 = 0.0): + def apply_forces(self, rod: RodType, time: Float = 0.0) -> None: force_on_one_element = self.force / rod.n_elems rod.external_forces += force_on_one_element @@ -284,7 +305,7 @@ def __init__( rest_lengths, ramp_up_time, with_spline=False, - ): + ) -> None: """ Parameters @@ -335,7 +356,7 @@ def __init__( else: self.my_spline = np.full_like(self.s, fill_value=1.0) - def apply_torques(self, rod: RodType, time: np.float64 = 0.0): + def apply_torques(self, rod: RodType, time: Float = 0.0) -> None: self.compute_muscle_torques( time, self.my_spline, @@ -388,7 +409,10 @@ def compute_muscle_torques( @njit(cache=True) -def inplace_addition(external_force_or_torque, force_or_torque): +def inplace_addition( + external_force_or_torque: NDArray[np.floating], + force_or_torque: NDArray[np.floating], +) -> None: """ This function does inplace addition. First argument `external_force_or_torque` is the system.external_forces @@ -411,7 +435,10 @@ def inplace_addition(external_force_or_torque, force_or_torque): @njit(cache=True) -def inplace_substraction(external_force_or_torque, force_or_torque): +def inplace_substraction( + external_force_or_torque: NDArray[np.floating], + force_or_torque: NDArray[np.floating], +) -> None: """ This function does inplace substraction. First argument `external_force_or_torque` is the system.external_forces @@ -460,12 +487,12 @@ class EndpointForcesSinusoidal(NoForces): def __init__( self, - start_force_mag, - end_force_mag, - ramp_up_time=0.0, - tangent_direction=np.array([0, 0, 1]), - normal_direction=np.array([0, 1, 0]), - ): + start_force_mag: Float, + end_force_mag: Float, + ramp_up_time: Float = 0.0, + tangent_direction: NDArray[np.floating] = np.array([0, 0, 1]), + normal_direction: NDArray[np.floating] = np.array([0, 1, 0]), + ) -> None: """ Parameters @@ -495,7 +522,7 @@ def __init__( assert ramp_up_time >= 0.0 self.ramp_up_time = ramp_up_time - def apply_forces(self, system: SystemType, time=0.0): + def apply_forces(self, system: SystemType, time: Float = 0.0) -> None: if time < self.ramp_up_time: # When time smaller than ramp up time apply the force in normal direction diff --git a/elastica/interaction.py b/elastica/interaction.py index 95d4602e1..7d26503cc 100644 --- a/elastica/interaction.py +++ b/elastica/interaction.py @@ -1,6 +1,7 @@ __doc__ = """ Numba implementation module containing interactions between a rod and its environment.""" +from typing import Any, NoReturn import numpy as np from elastica.external_forces import NoForces from numba import njit @@ -14,8 +15,12 @@ _calculate_contact_forces_cylinder_plane, ) +from numpy.typing import NDArray -def find_slipping_elements(velocity_slip, velocity_threshold): +from elastica.typing import Float, SystemType + + +def find_slipping_elements(velocity_slip: Any, velocity_threshold: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._find_slipping_elements()\n" @@ -23,7 +28,7 @@ def find_slipping_elements(velocity_slip, velocity_threshold): ) -def node_to_element_mass_or_force(input): +def node_to_element_mass_or_force(input: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._node_to_element_mass_or_force()\n" @@ -31,7 +36,7 @@ def node_to_element_mass_or_force(input): ) -def nodes_to_elements(input): +def nodes_to_elements(input: Any) -> NoReturn: # Remove the function beyond v0.4.0 raise NotImplementedError( "This function is removed in v0.3.1. Please use\n" @@ -41,7 +46,9 @@ def nodes_to_elements(input): @njit(cache=True) -def elements_to_nodes_inplace(vector_in_element_frame, vector_in_node_frame): +def elements_to_nodes_inplace( + vector_in_element_frame: Any, vector_in_node_frame: Any +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._elements_to_nodes_inplace()\n" @@ -74,7 +81,13 @@ class InteractionPlane: """ - def __init__(self, k, nu, plane_origin, plane_normal): + def __init__( + self, + k: Float, + nu: Float, + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + ) -> None: """ Parameters @@ -96,7 +109,7 @@ def __init__(self, k, nu, plane_origin, plane_normal): self.plane_normal = plane_normal.reshape(3) self.surface_tol = 1e-4 - def apply_normal_force(self, system): + def apply_normal_force(self, system: SystemType): """ In the case of contact with the plane, this function computes the plane reaction force on the element. @@ -130,18 +143,18 @@ def apply_normal_force(self, system): def apply_normal_force_numba( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - radius, - mass, - position_collection, - velocity_collection, - internal_forces, - external_forces, -): + plane_origin: Any, + plane_normal: Any, + surface_tol: Any, + k: Any, + nu: Any, + radius: Any, + mass: Any, + position_collection: Any, + velocity_collection: Any, + internal_forces: Any, + external_forces: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For rod plane contact please use: \n" "elastica._contact_functions._calculate_contact_forces_rod_plane() \n" @@ -186,14 +199,14 @@ class AnisotropicFrictionalPlane(NoForces, InteractionPlane): def __init__( self, - k, - nu, - plane_origin, - plane_normal, - slip_velocity_tol, - static_mu_array, - kinetic_mu_array, - ): + k: Float, + nu: Float, + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + slip_velocity_tol: Float, + static_mu_array: NDArray[np.floating], + kinetic_mu_array: NDArray[np.floating], + ) -> None: """ Parameters @@ -232,7 +245,7 @@ def __init__( # kinetic and static friction should separate functions # for now putting them together to figure out common variables - def apply_forces(self, system, time=0.0): + def apply_forces(self, system: SystemType, time: Float = 0.0) -> None: """ Call numba implementation to apply friction forces Parameters @@ -269,30 +282,30 @@ def apply_forces(self, system, time=0.0): def anisotropic_friction( - plane_origin, - plane_normal, - surface_tol, - slip_velocity_tol, - k, - nu, - kinetic_mu_forward, - kinetic_mu_backward, - kinetic_mu_sideways, - static_mu_forward, - static_mu_backward, - static_mu_sideways, - radius, - mass, - tangents, - position_collection, - director_collection, - velocity_collection, - omega_collection, - internal_forces, - external_forces, - internal_torques, - external_torques, -): + plane_origin: Any, + plane_normal: Any, + surface_tol: Any, + slip_velocity_tol: Any, + k: Any, + nu: Any, + kinetic_mu_forward: Any, + kinetic_mu_backward: Any, + kinetic_mu_sideways: Any, + static_mu_forward: Any, + static_mu_backward: Any, + static_mu_sideways: Any, + radius: Any, + mass: Any, + tangents: Any, + position_collection: Any, + director_collection: Any, + velocity_collection: Any, + omega_collection: Any, + internal_forces: Any, + external_forces: Any, + internal_torques: Any, + external_torques: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For anisotropic_friction please use: \n" "elastica._contact_functions._calculate_contact_forces_rod_plane_with_anisotropic_friction() \n" @@ -302,7 +315,7 @@ def anisotropic_friction( # Slender body module @njit(cache=True) -def sum_over_elements(input): +def sum_over_elements(input: NDArray[np.floating]) -> Float: """ This function sums all elements of the input array. Using a Numba njit decorator shows better performance @@ -334,14 +347,14 @@ def sum_over_elements(input): This version: 513 ns ± 24.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each) """ - output = 0.0 + output: Float = 0.0 for i in range(input.shape[0]): output += input[i] return output -def node_to_element_position(node_position_collection): +def node_to_element_position(node_position_collection: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For node-to-element_position() interpolation please use: \n" "elastica.contact_utils._node_to_element_position() for rod position \n" @@ -349,7 +362,7 @@ def node_to_element_position(node_position_collection): ) -def node_to_element_velocity(mass, node_velocity_collection): +def node_to_element_velocity(mass, node_velocity_collection: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For node-to-element_velocity() interpolation please use: \n" "elastica.contact_utils._node_to_element_velocity() for rod velocity. \n" @@ -357,7 +370,7 @@ def node_to_element_velocity(mass, node_velocity_collection): ) -def node_to_element_pos_or_vel(vector_in_node_frame): +def node_to_element_pos_or_vel(vector_in_node_frame: Any) -> NoReturn: # Remove the function beyond v0.4.0 raise NotImplementedError( "This function is removed in v0.3.0. For node-to-element interpolation please use: \n" @@ -369,8 +382,13 @@ def node_to_element_pos_or_vel(vector_in_node_frame): @njit(cache=True) def slender_body_forces( - tangents, velocity_collection, dynamic_viscosity, lengths, radius, mass -): + tangents: NDArray[np.floating], + velocity_collection: NDArray[np.floating], + dynamic_viscosity: Float, + lengths: NDArray[np.floating], + radius: NDArray[np.floating], + mass: NDArray[np.floating], +) -> NDArray[np.floating]: r""" This function computes hydrodynamic forces on a body using slender body theory. The below implementation is from Eq. 4.13 in Gazzola et al. RSoS. (2018). @@ -481,7 +499,7 @@ class SlenderBodyTheory(NoForces): """ - def __init__(self, dynamic_viscosity): + def __init__(self, dynamic_viscosity: Float) -> None: """ Parameters @@ -492,7 +510,7 @@ def __init__(self, dynamic_viscosity): super(SlenderBodyTheory, self).__init__() self.dynamic_viscosity = dynamic_viscosity - def apply_forces(self, system, time=0.0): + def apply_forces(self, system: SystemType, time: Float = 0.0) -> None: """ This function applies hydrodynamic forces on body using the slender body theory given in @@ -518,14 +536,20 @@ def apply_forces(self, system, time=0.0): # base class for interaction # only applies normal force no friction class InteractionPlaneRigidBody: - def __init__(self, k, nu, plane_origin, plane_normal): + def __init__( + self, + k: Float, + nu: Float, + plane_origin: NDArray[np.floating], + plane_normal: NDArray[np.floating], + ) -> None: self.k = k self.nu = nu self.plane_origin = plane_origin.reshape(3, 1) self.plane_normal = plane_normal.reshape(3) self.surface_tol = 1e-4 - def apply_normal_force(self, system): + def apply_normal_force(self, system: SystemType): """ This function computes the plane force response on the rigid body, in the case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper @@ -553,16 +577,16 @@ def apply_normal_force(self, system): @njit(cache=True) def apply_normal_force_numba_rigid_body( - plane_origin, - plane_normal, - surface_tol, - k, - nu, - length, - position_collection, - velocity_collection, - external_forces, -): + plane_origin: Any, + plane_normal: Any, + surface_tol: Any, + k: Any, + nu: Any, + length: Any, + position_collection: Any, + velocity_collection: Any, + external_forces: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. For cylinder plane contact please use: \n" diff --git a/elastica/joint.py b/elastica/joint.py index e37d6f058..9d45ba978 100644 --- a/elastica/joint.py +++ b/elastica/joint.py @@ -1,8 +1,12 @@ __doc__ = """ Module containing joint classes to connect multiple rods together. """ + +from typing import Any, NoReturn, Optional + from elastica._rotations import _inv_rotate -from elastica.typing import SystemType, RodType +from elastica.typing import SystemType, RodType, Float import numpy as np import logging +from numpy.typing import NDArray class FreeJoint: @@ -27,7 +31,7 @@ class FreeJoint: # pass the k and nu for the forces # also the necessary rods for the joint # indices should be 0 or -1, we will provide wrappers for users later - def __init__(self, k, nu): + def __init__(self, k: Float, nu: Float) -> None: """ Parameters @@ -42,8 +46,12 @@ def __init__(self, k, nu): self.nu = nu def apply_forces( - self, system_one: SystemType, index_one, system_two: SystemType, index_two - ): + self, + system_one: SystemType, + index_one: int, + system_two: SystemType, + index_two: int, + ) -> None: """ Apply joint force to the connected rod objects. @@ -81,8 +89,12 @@ def apply_forces( return def apply_torques( - self, system_one: SystemType, index_one, system_two: SystemType, index_two - ): + self, + system_one: SystemType, + index_one: int, + system_two: SystemType, + index_two: int, + ) -> None: """ Apply restoring joint torques to the connected rod objects. @@ -127,7 +139,9 @@ class HingeJoint(FreeJoint): """ # TODO: IN WRAPPER COMPUTE THE NORMAL DIRECTION OR ASK USER TO GIVE INPUT, IF NOT THROW ERROR - def __init__(self, k, nu, kt, normal_direction): + def __init__( + self, k: Float, nu: Float, kt: Float, normal_direction: NDArray[np.floating] + ) -> None: """ Parameters @@ -154,19 +168,19 @@ def __init__(self, k, nu, kt, normal_direction): def apply_forces( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: return super().apply_forces(system_one, index_one, system_two, index_two) def apply_torques( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: # current tangent direction of the `index_two` element of system two system_two_tangent = system_two.director_collection[2, :, index_two] @@ -215,7 +229,14 @@ class FixedJoint(FreeJoint): is enforced. """ - def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): + def __init__( + self, + k: Float, + nu: Float, + kt: Float, + nut: Float = 0.0, + rest_rotation_matrix: Optional[NDArray[np.floating]] = None, + ) -> None: """ Parameters @@ -254,19 +275,19 @@ def __init__(self, k, nu, kt, nut=0.0, rest_rotation_matrix=None): def apply_forces( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: return super().apply_forces(system_one, index_one, system_two, index_two) def apply_torques( self, system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: # collect directors of systems one and two # note that systems can be either rods or rigid bodies system_one_director = system_one.director_collection[..., index_one] @@ -311,10 +332,10 @@ def apply_torques( def get_relative_rotation_two_systems( system_one: SystemType, - index_one, + index_one: int, system_two: SystemType, - index_two, -): + index_two: int, +) -> NDArray[np.floating]: """ Compute the relative rotation matrix C_12 between system one and system two at the specified elements. @@ -362,7 +383,7 @@ def get_relative_rotation_two_systems( # everything below this comment should be removed beyond v0.4.0 -def _dot_product(a, b): +def _dot_product(a: Any, b: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._dot_product()\n" @@ -370,7 +391,7 @@ def _dot_product(a, b): ) -def _norm(a): +def _norm(a: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._norm()\n" @@ -378,7 +399,7 @@ def _norm(a): ) -def _clip(x, low, high): +def _clip(x: Any, low: Any, high: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._clip()\n" @@ -386,7 +407,7 @@ def _clip(x, low, high): ) -def _out_of_bounds(x, low, high): +def _out_of_bounds(x: Any, low: Any, high: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._out_of_bounds()\n" @@ -394,7 +415,7 @@ def _out_of_bounds(x, low, high): ) -def _find_min_dist(x1, e1, x2, e2): +def _find_min_dist(x1: Any, e1: Any, x2: Any, e2: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._find_min_dist()\n" @@ -403,25 +424,25 @@ def _find_min_dist(x1, e1, x2, e2): def _calculate_contact_forces_rod_rigid_body( - x_collection_rod, - edge_collection_rod, - x_cylinder_center, - x_cylinder_tip, - edge_cylinder, - radii_sum, - length_sum, - internal_forces_rod, - external_forces_rod, - external_forces_cylinder, - external_torques_cylinder, - cylinder_director_collection, - velocity_rod, - velocity_cylinder, - contact_k, - contact_nu, - velocity_damping_coefficient, - friction_coefficient, -): + x_collection_rod: Any, + edge_collection_rod: Any, + x_cylinder_center: Any, + x_cylinder_tip: Any, + edge_cylinder: Any, + radii_sum: Any, + length_sum: Any, + internal_forces_rod: Any, + external_forces_rod: Any, + external_forces_cylinder: Any, + external_torques_cylinder: Any, + cylinder_director_collection: Any, + velocity_rod: Any, + velocity_cylinder: Any, + contact_k: Any, + contact_nu: Any, + velocity_damping_coefficient: Any, + friction_coefficient: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica._contact_functions._calculate_contact_forces_rod_cylinder()\n" @@ -430,23 +451,23 @@ def _calculate_contact_forces_rod_rigid_body( def _calculate_contact_forces_rod_rod( - x_collection_rod_one, - radius_rod_one, - length_rod_one, - tangent_rod_one, - velocity_rod_one, - internal_forces_rod_one, - external_forces_rod_one, - x_collection_rod_two, - radius_rod_two, - length_rod_two, - tangent_rod_two, - velocity_rod_two, - internal_forces_rod_two, - external_forces_rod_two, - contact_k, - contact_nu, -): + x_collection_rod_one: Any, + radius_rod_one: Any, + length_rod_one: Any, + tangent_rod_one: Any, + velocity_rod_one: Any, + internal_forces_rod_one: Any, + external_forces_rod_one: Any, + x_collection_rod_two: Any, + radius_rod_two: Any, + length_rod_two: Any, + tangent_rod_two: Any, + velocity_rod_two: Any, + internal_forces_rod_two: Any, + external_forces_rod_two: Any, + contact_k: Any, + contact_nu: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica._contact_functions._calculate_contact_forces_rod_rod()\n" @@ -455,15 +476,15 @@ def _calculate_contact_forces_rod_rod( def _calculate_contact_forces_self_rod( - x_collection_rod, - radius_rod, - length_rod, - tangent_rod, - velocity_rod, - external_forces_rod, - contact_k, - contact_nu, -): + x_collection_rod: Any, + radius_rod: Any, + length_rod: Any, + tangent_rod: Any, + velocity_rod: Any, + external_forces_rod: Any, + contact_k: Any, + contact_nu: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica._contact_functions._calculate_contact_forces_self_rod()\n" @@ -471,7 +492,7 @@ def _calculate_contact_forces_self_rod( ) -def _aabbs_not_intersecting(aabb_one, aabb_two): +def _aabbs_not_intersecting(aabb_one: Any, aabb_two: Any) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._aabbs_not_intersecting()\n" @@ -480,14 +501,14 @@ def _aabbs_not_intersecting(aabb_one, aabb_two): def _prune_using_aabbs_rod_rigid_body( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - cylinder_position, - cylinder_director, - cylinder_radius, - cylinder_length, -): + rod_one_position_collection: Any, + rod_one_radius_collection: Any, + rod_one_length_collection: Any, + cylinder_position: Any, + cylinder_director: Any, + cylinder_radius: Any, + cylinder_length: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._prune_using_aabbs_rod_cylinder()\n" @@ -496,13 +517,13 @@ def _prune_using_aabbs_rod_rigid_body( def _prune_using_aabbs_rod_rod( - rod_one_position_collection, - rod_one_radius_collection, - rod_one_length_collection, - rod_two_position_collection, - rod_two_radius_collection, - rod_two_length_collection, -): + rod_one_position_collection: Any, + rod_one_radius_collection: Any, + rod_one_length_collection: Any, + rod_two_position_collection: Any, + rod_two_radius_collection: Any, + rod_two_length_collection: Any, +) -> NoReturn: raise NotImplementedError( "This function is removed in v0.3.2. Please use\n" "elastica.contact_utils._prune_using_aabbs_rod_rod()\n" @@ -555,7 +576,13 @@ class ExternalContact(FreeJoint): # potentially dangerous as it does not deal with "end" conditions # correctly. - def __init__(self, k, nu, velocity_damping_coefficient=0, friction_coefficient=0): + def __init__( + self, + k: Float, + nu: Float, + velocity_damping_coefficient: Float = 0, + friction_coefficient: Float = 0, + ) -> None: """ Parameters @@ -587,10 +614,10 @@ def __init__(self, k, nu, velocity_damping_coefficient=0, friction_coefficient=0 def apply_forces( self, rod_one: RodType, - index_one, + index_one: int, rod_two: SystemType, - index_two, - ): + index_two: int, + ) -> None: # del index_one, index_two from elastica.contact_utils import ( _prune_using_aabbs_rod_cylinder, @@ -693,7 +720,7 @@ class SelfContact(FreeJoint): """ - def __init__(self, k, nu): + def __init__(self, k: Float, nu: Float) -> None: super().__init__(k, nu) log = logging.getLogger(self.__class__.__name__) log.warning( @@ -705,7 +732,9 @@ def __init__(self, k, nu): "The option to use the SelfContact joint for the rod self contact will be removed in the future (v0.3.3).\n" ) - def apply_forces(self, rod_one: RodType, index_one, rod_two: SystemType, index_two): + def apply_forces( + self, rod_one: RodType, index_one: int, rod_two: SystemType, index_two: int + ) -> None: # del index_one, index_two from elastica._contact_functions import ( _calculate_contact_forces_self_rod, diff --git a/elastica/restart.py b/elastica/restart.py index 1b5fab267..dbc2ee935 100644 --- a/elastica/restart.py +++ b/elastica/restart.py @@ -5,8 +5,12 @@ from itertools import groupby from .memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody +from typing import Iterable, Iterator, Any -def all_equal(iterable): +from elastica.typing import Float + + +def all_equal(iterable: Iterable[Any]) -> bool: """ Checks if all elements of list are equal. Parameters @@ -20,11 +24,14 @@ def all_equal(iterable): ---------- https://stackoverflow.com/questions/3844801/check-if-all-elements-in-a-list-are-identical """ - g = groupby(iterable) + g: Iterator[Any] = groupby(iterable) return next(g, True) and not next(g, False) -def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): +# TODO: simulator should have better typing +def save_state( + simulator: Iterable, directory: str = "", time: Float = 0.0, verbose: bool = False +) -> None: """ Save state parameters of each rod. TODO : environment list variable is not uniform at the current stage of development. @@ -53,7 +60,10 @@ def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): print("Save complete: {}".format(directory)) -def load_state(simulator, directory: str = "", verbose: bool = False): +# TODO: simulator should have better typing +def load_state( + simulator: Iterable, directory: str = "", verbose: bool = False +) -> float: """ Load the rod-state. Compatibale with 'save_state' method. If the save-file does not exist, it returns error. @@ -72,7 +82,7 @@ def load_state(simulator, directory: str = "", verbose: bool = False): time : float Simulation time of systems when they are saved. """ - time_list = [] # Simulation time of rods when they are saved. + time_list: list[float] = [] # Simulation time of rods when they are saved. for idx, rod in enumerate(simulator): if isinstance(rod, MemoryBlockCosseratRod) or isinstance( rod, MemoryBlockRigidBody diff --git a/elastica/rod/cosserat_rod.py b/elastica/rod/cosserat_rod.py index a6c626389..5007b799f 100644 --- a/elastica/rod/cosserat_rod.py +++ b/elastica/rod/cosserat_rod.py @@ -2,6 +2,7 @@ import numpy as np +from numpy.typing import NDArray import functools import numba from elastica.rod import RodBase @@ -22,6 +23,8 @@ ) from typing import Optional +from elastica.typing import Float + position_difference_kernel = _difference position_average = _average @@ -147,39 +150,39 @@ class CosseratRod(RodBase, KnotTheory): def __init__( self, - n_elements, - position, - velocity, - omega, - acceleration, - angular_acceleration, - directors, - radius, - mass_second_moment_of_inertia, - inv_mass_second_moment_of_inertia, - shear_matrix, - bend_matrix, - density, - volume, - mass, - internal_forces, - internal_torques, - external_forces, - external_torques, - lengths, - rest_lengths, - tangents, - dilatation, - dilatation_rate, - voronoi_dilatation, - rest_voronoi_lengths, - sigma, - kappa, - rest_sigma, - rest_kappa, - internal_stress, - internal_couple, - ring_rod_flag, + n_elements: int, + position: NDArray[np.floating], + velocity: NDArray[np.floating], + omega: NDArray[np.floating], + acceleration: NDArray[np.floating], + angular_acceleration: NDArray[np.floating], + directors: NDArray[np.floating], + radius: NDArray[np.floating], + mass_second_moment_of_inertia: NDArray[np.floating], + inv_mass_second_moment_of_inertia: NDArray[np.floating], + shear_matrix: NDArray[np.floating], + bend_matrix: NDArray[np.floating], + density: NDArray[np.floating], + volume: NDArray[np.floating], + mass: NDArray[np.floating], + internal_forces: NDArray[np.floating], + internal_torques: NDArray[np.floating], + external_forces: NDArray[np.floating], + external_torques: NDArray[np.floating], + lengths: NDArray[np.floating], + rest_lengths: NDArray[np.floating], + tangents: NDArray[np.floating], + dilatation: NDArray[np.floating], + dilatation_rate: NDArray[np.floating], + voronoi_dilatation: NDArray[np.floating], + rest_voronoi_lengths: NDArray[np.floating], + sigma: NDArray[np.floating], + kappa: NDArray[np.floating], + rest_sigma: NDArray[np.floating], + rest_kappa: NDArray[np.floating], + internal_stress: NDArray[np.floating], + internal_couple: NDArray[np.floating], + ring_rod_flag: bool, ): self.n_elems = n_elements self.position_collection = position @@ -242,9 +245,9 @@ def __init__( def straight_rod( cls, n_elements: int, - start: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + start: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, @@ -390,9 +393,9 @@ def straight_rod( def ring_rod( cls, n_elements: int, - ring_center_position: np.ndarray, - direction: np.ndarray, - normal: np.ndarray, + ring_center_position: NDArray[np.floating], + direction: NDArray[np.floating], + normal: NDArray[np.floating], base_length: float, base_radius: float, density: float, @@ -533,7 +536,7 @@ def ring_rod( ring_rod_flag, ) - def compute_internal_forces_and_torques(self, time): + def compute_internal_forces_and_torques(self, time: float): """ Compute internal forces and torques. We need to compute internal forces and torques before the acceleration because they are used in interaction. Thus in order to speed up simulation, we will compute internal forces and torques @@ -588,7 +591,7 @@ def compute_internal_forces_and_torques(self, time): ) # Interface to time-stepper mixins (Symplectic, Explicit), which calls this method - def update_accelerations(self, time): + def update_accelerations(self, time: float): """ Updates the acceleration variables @@ -610,12 +613,12 @@ def update_accelerations(self, time): self.dilatation, ) - def zeroed_out_external_forces_and_torques(self, time): + def zeroed_out_external_forces_and_torques(self, time: Float) -> None: _zeroed_out_external_forces_and_torques( self.external_forces, self.external_torques ) - def compute_translational_energy(self): + def compute_translational_energy(self) -> NDArray[np.floating]: """ Compute total translational energy of the rod at the instance. """ @@ -629,7 +632,7 @@ def compute_translational_energy(self): ).sum() ) - def compute_rotational_energy(self): + def compute_rotational_energy(self) -> NDArray[np.floating]: """ Compute total rotational energy of the rod at the instance. """ @@ -639,7 +642,7 @@ def compute_rotational_energy(self): ) return 0.5 * np.einsum("ik,ik->k", self.omega_collection, J_omega_upon_e).sum() - def compute_velocity_center_of_mass(self): + def compute_velocity_center_of_mass(self) -> NDArray[np.floating]: """ Compute velocity center of mass of the rod at the instance. """ @@ -648,7 +651,7 @@ def compute_velocity_center_of_mass(self): return sum_mass_times_velocity / self.mass.sum() - def compute_position_center_of_mass(self): + def compute_position_center_of_mass(self) -> NDArray[np.floating]: """ Compute position center of mass of the rod at the instance. """ @@ -657,7 +660,7 @@ def compute_position_center_of_mass(self): return sum_mass_times_position / self.mass.sum() - def compute_bending_energy(self): + def compute_bending_energy(self) -> NDArray[np.floating]: """ Compute total bending energy of the rod at the instance. """ @@ -673,7 +676,7 @@ def compute_bending_energy(self): ).sum() ) - def compute_shear_energy(self): + def compute_shear_energy(self) -> NDArray[np.floating]: """ Compute total shear energy of the rod at the instance. """ diff --git a/elastica/rod/data_structures.py b/elastica/rod/data_structures.py index 3c8b40328..6a1121039 100644 --- a/elastica/rod/data_structures.py +++ b/elastica/rod/data_structures.py @@ -9,7 +9,7 @@ # FIXME : Explicit Stepper doesn't work as States lose the # views they initially had when working with a timestepper. # class _RodExplicitStepperMixin: -# def __init__(self): +# def __init__(self) -> None: # ( # self.state, # self.__deriv_state, @@ -43,7 +43,7 @@ class _RodSymplecticStepperMixin: - def __init__(self): + def __init__(self) -> None: self.kinematic_states = _KinematicState( self.position_collection, self.director_collection ) @@ -296,7 +296,7 @@ class _DerivativeState: /multiplication used. """ - def __init__(self, _unused_n_elems: int, rate_collection_view): + def __init__(self, _unused_n_elems: int, rate_collection_view) -> None: """ Parameters ---------- @@ -388,7 +388,7 @@ class _KinematicState: only these methods are provided. """ - def __init__(self, position_collection_view, director_collection_view): + def __init__(self, position_collection_view, director_collection_view) -> None: """ Parameters ---------- diff --git a/elastica/rod/factory_function.py b/elastica/rod/factory_function.py index 6eca6b5ac..fa2ad8f19 100644 --- a/elastica/rod/factory_function.py +++ b/elastica/rod/factory_function.py @@ -3,22 +3,24 @@ import logging import numpy as np from numpy.testing import assert_allclose +from numpy.typing import NDArray +from elastica.typing import Float from elastica.utils import MaxDimension, Tolerance from elastica._linalg import _batch_cross, _batch_norm, _batch_dot def allocate( - n_elements, - direction, - normal, - base_length, - base_radius, - density, - youngs_modulus: float, + n_elements: int, + direction: NDArray[np.floating], + normal: NDArray[np.floating], + base_length: Float, + base_radius: Float, + density: Float, + youngs_modulus: Float, *, rod_origin_position: np.ndarray, ring_rod_flag: bool, - shear_modulus: Optional[float] = None, + shear_modulus: Optional[Float] = None, position: Optional[np.ndarray] = None, directors: Optional[np.ndarray] = None, rest_sigma: Optional[np.ndarray] = None, @@ -335,14 +337,14 @@ def allocate( """ -def _assert_dim(vector, max_dim: int, name: str): +def _assert_dim(vector: np.ndarray, max_dim: int, name: str): assert vector.ndim < max_dim, ( f"Input {name} dimension is not correct {vector.shape}" + f" It should be maximum {max_dim}D vector or single floating number." ) -def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): +def _assert_shape(array: np.ndarray, expected_shape: Tuple[int, ...], name: str): assert array.shape == expected_shape, ( f"Given {name} shape is not correct, it should be " + str(expected_shape) @@ -351,7 +353,9 @@ def _assert_shape(array: np.ndarray, expected_shape: Tuple[int], name: str): ) -def _position_validity_checker(position, start, n_elements): +def _position_validity_checker( + position: NDArray[np.floating], start: NDArray[np.floating], n_elements: int +): """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements + 1), "position") @@ -367,7 +371,9 @@ def _position_validity_checker(position, start, n_elements): ) -def _directors_validity_checker(directors, tangents, n_elements): +def _directors_validity_checker( + directors: NDArray[np.floating], tangents: NDArray[np.floating], n_elements: int +): """Checker on user-defined directors validity""" _assert_shape( directors, (MaxDimension.value(), MaxDimension.value(), n_elements), "directors" @@ -413,7 +419,11 @@ def _directors_validity_checker(directors, tangents, n_elements): ) -def _position_validity_checker_ring_rod(position, ring_center_position, n_elements): +def _position_validity_checker_ring_rod( + position: NDArray[np.floating], + ring_center_position: NDArray[np.floating], + n_elements: int, +): """Checker on user-defined position validity""" _assert_shape(position, (MaxDimension.value(), n_elements), "position") diff --git a/elastica/rod/knot_theory.py b/elastica/rod/knot_theory.py index 23d2b009f..a936dbda5 100644 --- a/elastica/rod/knot_theory.py +++ b/elastica/rod/knot_theory.py @@ -14,9 +14,11 @@ from numba import njit import numpy as np +from numpy.typing import NDArray from elastica.rod.rod_base import RodBase from elastica._linalg import _batch_norm, _batch_dot, _batch_cross +from elastica.typing import Float class KnotTheoryCompatibleProtocol(Protocol): @@ -46,7 +48,7 @@ class KnotTheory: KnotTheory can be mixed with any rod-class based on RodBase:: class MyRod(RodBase, KnotTheory): - def __init__(self): + def __init__(self) -> None: super().__init__() rod = MyRod(...) @@ -91,7 +93,7 @@ def compute_twist(self: MIXIN_PROTOCOL): def compute_writhe( self: MIXIN_PROTOCOL, type_of_additional_segment: str = "next_tangent", - alpha: float = 1.0, + alpha: Float = 1.0, ): """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -114,7 +116,7 @@ def compute_writhe( def compute_link( self: MIXIN_PROTOCOL, type_of_additional_segment: str = "next_tangent", - alpha: float = 1.0, + alpha: Float = 1.0, ): """ See :ref:`api/rods:Knot Theory (Mixin)` for the detail. @@ -138,7 +140,9 @@ def compute_link( )[0] -def compute_twist(center_line, normal_collection): +def compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +): """ Compute the twist of a rod, using center_line and normal collection. @@ -189,7 +193,9 @@ def compute_twist(center_line, normal_collection): @njit(cache=True) -def _compute_twist(center_line, normal_collection): +def _compute_twist( + center_line: NDArray[np.floating], normal_collection: NDArray[np.floating] +): """ Parameters ---------- @@ -264,7 +270,11 @@ def _compute_twist(center_line, normal_collection): return total_twist, local_twist -def compute_writhe(center_line, segment_length, type_of_additional_segment): +def compute_writhe( + center_line: NDArray[np.floating], + segment_length: Float, + type_of_additional_segment: str, +): """ This function computes the total writhe history of a rod. @@ -314,7 +324,7 @@ def compute_writhe(center_line, segment_length, type_of_additional_segment): @njit(cache=True) -def _compute_writhe(center_line): +def _compute_writhe(center_line: NDArray[np.floating]): """ Parameters ---------- @@ -386,10 +396,10 @@ def _compute_writhe(center_line): def compute_link( - center_line: np.ndarray, - normal_collection: np.ndarray, - radius: np.ndarray, - segment_length: float, + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], + segment_length: Float, type_of_additional_segment: str, ): """ @@ -470,7 +480,11 @@ def compute_link( @njit(cache=True) -def _compute_auxiliary_line(center_line, normal_collection, radius): +def _compute_auxiliary_line( + center_line: NDArray[np.floating], + normal_collection: NDArray[np.floating], + radius: NDArray[np.floating], +): """ This function computes the auxiliary line using rod center line and normal collection. @@ -525,7 +539,9 @@ def _compute_auxiliary_line(center_line, normal_collection, radius): @njit(cache=True) -def _compute_link(center_line, auxiliary_line): +def _compute_link( + center_line: NDArray[np.floating], auxiliary_line: NDArray[np.floating] +): """ Parameters @@ -604,7 +620,10 @@ def _compute_link(center_line, auxiliary_line): @njit(cache=True) def _compute_auxiliary_line_added_segments( - beginning_direction, end_direction, auxiliary_line, segment_length + beginning_direction: NDArray[np.floating], + end_direction: NDArray[np.floating], + auxiliary_line: NDArray[np.floating], + segment_length: Float, ): """ This code is for computing position of added segments to the auxiliary line. @@ -647,7 +666,9 @@ def _compute_auxiliary_line_added_segments( @njit(cache=True) def _compute_additional_segment( - center_line, segment_length, type_of_additional_segment + center_line: NDArray[np.floating], + segment_length: Float, + type_of_additional_segment: str, ): """ This function adds two points at the end of center line. Distance from the center line is given by segment_length. diff --git a/elastica/rod/rod_base.py b/elastica/rod/rod_base.py index 15785a360..28e20b708 100644 --- a/elastica/rod/rod_base.py +++ b/elastica/rod/rod_base.py @@ -1,5 +1,7 @@ __doc__ = """Base class for rods""" +import numpy as np +from numpy.typing import NDArray class RodBase: """ @@ -11,8 +13,13 @@ class RodBase: """ - def __init__(self): + def __init__(self) -> None: """ RodBase does not take any arguments. """ - pass + self.position_collection: NDArray[np.floating] + self.omega_collection: NDArray[np.floating] + self.acceleration_collection: NDArray[np.floating] + self.alpha_collection: NDArray[np.floating] + self.external_forces: NDArray[np.floating] + self.external_torques: NDArray[np.floating] diff --git a/elastica/transformations.py b/elastica/transformations.py index b002e9569..e19c78b97 100644 --- a/elastica/transformations.py +++ b/elastica/transformations.py @@ -7,14 +7,18 @@ _skew_symmetrize, _rotate, ) +from elastica.typing import Float from .utils import MaxDimension, isqrt +from numpy.typing import NDArray # TODO Complete, but nicer interface, evolve it eventually -def format_vector_shape(vector_collection): +def format_vector_shape( + vector_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Function for formatting vector shapes into correct format @@ -59,7 +63,9 @@ def format_vector_shape(vector_collection): return vector_collection -def format_matrix_shape(matrix_collection): +def format_matrix_shape( + matrix_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Formats input matrix into correct format @@ -77,7 +83,7 @@ def format_matrix_shape(matrix_collection): # check first two dimensions are same and matrix is square # other possibility is one dimension is dim**2 and other is blocksize, # we need to convert the matrix in that case. - def assert_proper_square(num1): + def assert_proper_square(num1: int) -> int: sqrt_num = isqrt(num1) assert sqrt_num**2 == num1, "Matrix dimension passed is not a perfect square" return sqrt_num @@ -136,12 +142,14 @@ def assert_proper_square(num1): return matrix_collection -def skew_symmetrize(vector): +def skew_symmetrize(vector: NDArray[np.floating]) -> NDArray[np.floating]: vector = format_vector_shape(vector) return _skew_symmetrize(vector) -def inv_skew_symmetrize(matrix_collection): +def inv_skew_symmetrize( + matrix_collection: NDArray[np.floating], +) -> NDArray[np.floating]: """ Safe wrapper around inv_skew_symmetrize that does checking and formatting on type of matrix_collection using format_matrix_shape @@ -167,7 +175,9 @@ def inv_skew_symmetrize(matrix_collection): raise ValueError("matrix_collection passed is not skew-symmetric") -def rotate(matrix, scale, axis): +def rotate( + matrix: NDArray[np.floating], scale: Float, axis: NDArray[np.floating] +) -> NDArray[np.floating]: """ This function takes single or multiple frames as matrix. Then rotates these frames around a single axis for all frames, or can rotate each frame around its own diff --git a/elastica/typing.py b/elastica/typing.py index 7c3df8702..34f55d1e2 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -7,6 +7,7 @@ from typing import Type, Union, Callable, Any from typing import TypeAlias +import numpy as np if TYPE_CHECKING: # Used for type hinting without circular imports @@ -53,7 +54,7 @@ # [SymplecticStepperProtocol, np.floating], np.floating # ] OperatorType: TypeAlias = Callable[ - Any, Any + ..., Any ] # TODO: Maybe can be more specific. Up for discussion. SteppersOperatorsType: TypeAlias = tuple[tuple[OperatorType, ...], ...] # tuple[Union[PrefactorOperatorType, StepOperatorType, NoOpType, np.floating], ...], ... @@ -82,3 +83,7 @@ RodType: TypeAlias = Type[RodBase] SystemCollectionType: TypeAlias = BaseSystemCollection AllowedContactType: TypeAlias = Union[SystemType, Type[SurfaceBase]] + +# builtin float and numpy.floating are incompatible, so define a union that +# can be used when a general float is expected +Float: TypeAlias = Union[np.floating, float] diff --git a/elastica/utils.py b/elastica/utils.py index bbf9baa2d..1b7832766 100644 --- a/elastica/utils.py +++ b/elastica/utils.py @@ -1,12 +1,17 @@ """ Handy utilities """ +from typing import Generator, Iterable, Any, Literal, TypeVar import functools import numpy as np from numpy import finfo, float64 from itertools import islice from scipy.interpolate import BSpline +from numpy.typing import NDArray + +from elastica.typing import Float + # Slower than the python3.8 isqrt implementation for small ints # python isqrt : ~130 ns @@ -47,6 +52,8 @@ def isqrt(num: int) -> int: elif num == 0: return 0 + raise ValueError("num must be a positive number") + class MaxDimension: """ @@ -54,7 +61,7 @@ class MaxDimension: """ @staticmethod - def value(): + def value() -> Literal[3]: """ Returns spatial dimension @@ -67,7 +74,7 @@ def value(): class Tolerance: @staticmethod - def atol(): + def atol() -> np.floating: """ Static absolute tolerance method @@ -78,7 +85,7 @@ def atol(): return finfo(float64).eps * 1e4 @staticmethod - def rtol(): + def rtol() -> np.floating: """ Static relative tolerance method @@ -89,7 +96,7 @@ def rtol(): return finfo(float64).eps * 1e11 -def perm_parity(lst): +def perm_parity(lst: list[int]) -> int: """ Given a permutation of the digits 0..N in order as a list, returns its parity (or sign): +1 for even parity; -1 for odd. @@ -115,7 +122,10 @@ def perm_parity(lst): return parity -def grouper(iterable, n): +_T = TypeVar("_T") + + +def grouper(iterable: Iterable[_T], n: int) -> Generator[tuple[_T, ...], None, None]: """Collect data into fixed-length chunks or blocks" Parameters @@ -144,7 +154,7 @@ def grouper(iterable, n): yield group -def extend_instance(obj, cls): +def extend_instance(obj: Any, cls: Any) -> None: """ Apply mixins to a class instance after creation @@ -170,7 +180,9 @@ def extend_instance(obj, cls): obj.__class__ = type(base_cls_name, (cls, base_cls), {}) -def _bspline(t_coeff, l_centerline=1.0): +def _bspline( + t_coeff: NDArray, l_centerline: Float = 1.0 +) -> tuple[BSpline, NDArray, NDArray]: """Generates a bspline object that plots the spline interpolant for any vector x. Optionally takes in a centerline length, set to 1.0 by default and keep_pts for keeping record of control points @@ -198,7 +210,9 @@ def _bspline(t_coeff, l_centerline=1.0): return __bspline_impl__(control_pts, t_coeff, degree) -def __bspline_impl__(x_pts, t_c, degree): +def __bspline_impl__( + x_pts: NDArray, t_c: NDArray, degree: int +) -> tuple[BSpline, NDArray, NDArray]: """""" # Update the knots