Skip to content

Commit

Permalink
Improve typehinting at root and rod directory
Browse files Browse the repository at this point in the history
  • Loading branch information
ankith26 committed May 14, 2024
1 parent 63705a7 commit 89bfbc4
Show file tree
Hide file tree
Showing 22 changed files with 787 additions and 493 deletions.
32 changes: 22 additions & 10 deletions elastica/_calculus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__doc__ = """ Quadrature and difference kernels """
from typing import Any, Union
import numpy as np
from numpy import zeros, empty
from numpy.typing import NDArray
from numba import njit
from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import (
_reset_vector_ghost,
Expand All @@ -9,15 +11,17 @@


@functools.lru_cache(maxsize=2)
def _get_zero_array(dim, ndim):
def _get_zero_array(dim: int, ndim: int) -> Union[float, NDArray[np.floating], None]:
if ndim == 1:
return 0.0
if ndim == 2:
return np.zeros((dim, 1))

return None


@njit(cache=True)
def _trapezoidal(array_collection):
def _trapezoidal(array_collection: NDArray[np.floating]) -> NDArray[np.floating]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way
Expand Down Expand Up @@ -63,7 +67,9 @@ def _trapezoidal(array_collection):


@njit(cache=True)
def _trapezoidal_for_block_structure(array_collection, ghost_idx):
def _trapezoidal_for_block_structure(
array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
) -> NDArray[np.floating]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form
specifically for the block structure implementation and there is a reset function call, to reset
Expand Down Expand Up @@ -115,7 +121,9 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
def _two_point_difference(array_collection):
def _two_point_difference(
array_collection: NDArray[np.floating],
) -> NDArray[np.floating]:
"""
This function does differentiation.
Expand Down Expand Up @@ -156,7 +164,9 @@ def _two_point_difference(array_collection):


@njit(cache=True)
def _two_point_difference_for_block_structure(array_collection, ghost_idx):
def _two_point_difference_for_block_structure(
array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
) -> NDArray[np.floating]:
"""
This function does the differentiation, for Cosserat rod model equations. This form
specifically for the block structure implementation and there is a reset function call, to
Expand Down Expand Up @@ -207,7 +217,7 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
def _difference(vector):
def _difference(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes difference between elements of a batch vector.
Expand Down Expand Up @@ -238,7 +248,7 @@ def _difference(vector):


@njit(cache=True)
def _average(vector):
def _average(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes the average between elements of a vector.
Expand Down Expand Up @@ -268,7 +278,9 @@ def _average(vector):


@njit(cache=True)
def _clip_array(input_array, vmin, vmax):
def _clip_array(
input_array: NDArray[np.floating], vmin: np.floating, vmax: np.floating
) -> NDArray[np.floating]:
"""
This function clips an array values
between user defined minimum and maximum
Expand Down Expand Up @@ -304,7 +316,7 @@ def _clip_array(input_array, vmin, vmax):


@njit(cache=True)
def _isnan_check(array):
def _isnan_check(array: NDArray[Any]) -> bool:
"""
This function checks if there is any nan inside the array.
If there is nan, it returns True boolean.
Expand All @@ -324,7 +336,7 @@ def _isnan_check(array):
Python version: 2.24 µs ± 96.1 ns per loop
This version: 479 ns ± 6.49 ns per loop
"""
return np.isnan(array).any()
return bool(np.isnan(array).any())


position_difference_kernel = _difference
Expand Down
22 changes: 11 additions & 11 deletions elastica/_contact_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
import numba
import numpy as np
from numpy.typing import NDArray


@numba.njit(cache=True)
Expand Down Expand Up @@ -784,17 +785,16 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction(

@numba.njit(cache=True)
def _calculate_contact_forces_cylinder_plane(
plane_origin,
plane_normal,
surface_tol,
k,
nu,
length,
position_collection,
velocity_collection,
external_forces,
):

plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
k: np.floating,
nu: np.floating,
length: NDArray[np.floating],
position_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
external_forces: NDArray[np.floating],
) -> tuple[NDArray[np.floating], NDArray[np.intp]]:
# Compute plane response force
# total_forces = system.internal_forces + system.external_forces
total_forces = external_forces
Expand Down
45 changes: 33 additions & 12 deletions elastica/_linalg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__doc__ = """ Convenient linear algebra kernels """
import numpy as np
from numpy.typing import NDArray
from numba import njit
from numpy import sqrt
import functools
Expand All @@ -8,7 +9,7 @@


@functools.lru_cache(maxsize=1)
def levi_civita_tensor(dim):
def levi_civita_tensor(dim: int) -> NDArray[np.floating]:
"""
Parameters
Expand All @@ -28,7 +29,9 @@ def levi_civita_tensor(dim):


@njit(cache=True)
def _batch_matvec(matrix_collection, vector_collection):
def _batch_matvec(
matrix_collection: NDArray[np.floating], vector_collection: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function does batch matrix and batch vector product
Expand Down Expand Up @@ -59,7 +62,10 @@ def _batch_matvec(matrix_collection, vector_collection):


@njit(cache=True)
def _batch_matmul(first_matrix_collection, second_matrix_collection):
def _batch_matmul(
first_matrix_collection: NDArray[np.floating],
second_matrix_collection: NDArray[np.floating],
) -> NDArray[np.floating]:
"""
This is batch matrix matrix multiplication function. Only batch
of 3x3 matrices can be multiplied.
Expand Down Expand Up @@ -93,7 +99,10 @@ def _batch_matmul(first_matrix_collection, second_matrix_collection):


@njit(cache=True)
def _batch_cross(first_vector_collection, second_vector_collection):
def _batch_cross(
first_vector_collection: NDArray[np.floating],
second_vector_collection: NDArray[np.floating],
) -> NDArray[np.floating]:
"""
This function does cross product between two batch vectors.
Expand Down Expand Up @@ -133,7 +142,9 @@ def _batch_cross(first_vector_collection, second_vector_collection):


@njit(cache=True)
def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector):
def _batch_vec_oneD_vec_cross(
first_vector_collection: NDArray[np.floating], second_vector: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function does cross product between batch vector and a 1D vector.
Idea of having this function is that, for friction calculations, we dont
Expand Down Expand Up @@ -177,7 +188,9 @@ def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector):


@njit(cache=True)
def _batch_dot(first_vector, second_vector):
def _batch_dot(
first_vector: NDArray[np.floating], second_vector: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function does batch vec and batch vec dot product.
Parameters
Expand All @@ -204,7 +217,7 @@ def _batch_dot(first_vector, second_vector):


@njit(cache=True)
def _batch_norm(vector):
def _batch_norm(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes norm of a batch vector
Parameters
Expand Down Expand Up @@ -233,7 +246,9 @@ def _batch_norm(vector):


@njit(cache=True)
def _batch_product_i_k_to_ik(vector1, vector2):
def _batch_product_i_k_to_ik(
vector1: NDArray[np.floating], vector2: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function does outer product following 'i,k->ik'.
vector1 has shape of 3 and vector 2 has shape of blocksize
Expand Down Expand Up @@ -262,7 +277,9 @@ def _batch_product_i_k_to_ik(vector1, vector2):


@njit(cache=True)
def _batch_product_i_ik_to_k(vector1, vector2):
def _batch_product_i_ik_to_k(
vector1: NDArray[np.floating], vector2: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function does the following product 'i,ik->k'
This function do dot product between a vector of 3 elements
Expand Down Expand Up @@ -293,7 +310,9 @@ def _batch_product_i_ik_to_k(vector1, vector2):


@njit(cache=True)
def _batch_product_k_ik_to_ik(vector1, vector2):
def _batch_product_k_ik_to_ik(
vector1: NDArray[np.floating], vector2: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function does the following product 'k, ik->ik'
Parameters
Expand Down Expand Up @@ -322,7 +341,9 @@ def _batch_product_k_ik_to_ik(vector1, vector2):


@njit(cache=True)
def _batch_vector_sum(vector1, vector2):
def _batch_vector_sum(
vector1: NDArray[np.floating], vector2: NDArray[np.floating]
) -> NDArray[np.floating]:
"""
This function is for summing up two vectors. Although
this function is not faster than pure python implementation
Expand Down Expand Up @@ -352,7 +373,7 @@ def _batch_vector_sum(vector1, vector2):


@njit(cache=True)
def _batch_matrix_transpose(input_matrix):
def _batch_matrix_transpose(input_matrix: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function takes an batch input matrix and transpose it.
Parameters
Expand Down
Loading

0 comments on commit 89bfbc4

Please sign in to comment.