Improve typehinting to files at project root
ankith26 committed May 10, 2024
1 parent 98e9d0e commit d294e44
Showing 22 changed files with 731 additions and 460 deletions.
34 changes: 24 additions & 10 deletions elastica/_calculus.py
@@ -1,23 +1,29 @@
__doc__ = """ Quadrature and difference kernels """
from typing import Any, Union
import numpy as np
from numpy import zeros, empty
from numpy.typing import NDArray
from numba import njit
from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import (
_reset_vector_ghost,
)
import functools

from elastica.typing import Float


@functools.lru_cache(maxsize=2)
-def _get_zero_array(dim, ndim):
+def _get_zero_array(dim: int, ndim: int) -> Union[float, NDArray[np.floating], None]:
    if ndim == 1:
        return 0.0
    if ndim == 2:
        return np.zeros((dim, 1))

    return None
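For orientation, a quick sketch of what the new Union-style return annotation captures, with values read straight off the branches above (assumes the module is importable as in this commit):

import numpy as np
from elastica._calculus import _get_zero_array

assert _get_zero_array(3, 1) == 0.0                             # ndim == 1 -> float
assert np.array_equal(_get_zero_array(3, 2), np.zeros((3, 1)))  # ndim == 2 -> (dim, 1) array
assert _get_zero_array(3, 3) is None                            # anything else -> None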


@njit(cache=True)
-def _trapezoidal(array_collection):
+def _trapezoidal(array_collection: NDArray[np.floating]) -> NDArray[np.floating]:
    """
    Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way
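A usage sketch; the shape convention here is my reading of the folded docstring, namely an elemental input of shape (dim, blocksize) mapping to a nodal output of shape (dim, blocksize + 1):

import numpy as np
from elastica._calculus import _trapezoidal

rate = np.random.rand(3, 10)   # elemental quantity
nodal = _trapezoidal(rate)     # nodal quantity, shape (3, 11)
# same rule in pure NumPy: average neighbours, halve the end values
pad = np.zeros((3, 1))
assert np.allclose(nodal, 0.5 * (np.hstack((rate, pad)) + np.hstack((pad, rate))))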
@@ -63,7 +69,9 @@ def _trapezoidal(array_collection):


@njit(cache=True)
-def _trapezoidal_for_block_structure(array_collection, ghost_idx):
+def _trapezoidal_for_block_structure(
+    array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
+) -> NDArray[np.floating]:
    """
    Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form
    is specifically for the block structure implementation, and there is a reset function call to reset
@@ -115,7 +123,9 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
-def _two_point_difference(array_collection):
+def _two_point_difference(
+    array_collection: NDArray[np.floating],
+) -> NDArray[np.floating]:
    """
    This function does differentiation.
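A matching sketch for the difference kernel, again under my reading of the shape convention: (dim, blocksize) in, (dim, blocksize + 1) out, with signed boundary terms:

import numpy as np
from elastica._calculus import _two_point_difference

flux = np.random.rand(3, 10)
out = _two_point_difference(flux)  # shape (3, 11)
pad = np.zeros((3, 1))
assert np.allclose(out, np.hstack((flux, pad)) - np.hstack((pad, flux)))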
@@ -156,7 +166,9 @@ def _two_point_difference(array_collection):


@njit(cache=True)
-def _two_point_difference_for_block_structure(array_collection, ghost_idx):
+def _two_point_difference_for_block_structure(
+    array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
+) -> NDArray[np.floating]:
    """
    This function does the differentiation, for Cosserat rod model equations. This form
    is specifically for the block structure implementation, and there is a reset function call to
@@ -207,7 +219,7 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
-def _difference(vector):
+def _difference(vector: NDArray[np.floating]) -> NDArray[np.floating]:
    """
    This function computes difference between elements of a batch vector.
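Behaviourally this should match np.diff along the batch axis; a small cross-check (shapes illustrative):

import numpy as np
from elastica._calculus import _difference

positions = np.random.rand(3, 11)
assert np.allclose(_difference(positions), np.diff(positions, axis=-1))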
@@ -238,7 +250,7 @@ def _difference(vector):


@njit(cache=True)
-def _average(vector):
+def _average(vector: NDArray[np.floating]) -> NDArray[np.floating]:
    """
    This function computes the average between elements of a vector.
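And the averaging counterpart, assuming it takes the mean of consecutive entries of a 1D vector, as the docstring's wording suggests:

import numpy as np
from elastica._calculus import _average

nodal = np.random.rand(11)
assert np.allclose(_average(nodal), 0.5 * (nodal[:-1] + nodal[1:]))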
@@ -268,7 +280,9 @@ def _average(vector):


@njit(cache=True)
-def _clip_array(input_array, vmin, vmax):
+def _clip_array(
+    input_array: NDArray[np.floating], vmin: Float, vmax: Float
+) -> NDArray[np.floating]:
    """
    This function clips array values
    between a user-defined minimum and maximum
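The two bounds now carry the Float alias imported from elastica.typing (a scalar float type). Functionally the kernel should line up with np.clip; note that it may also clip in place:

import numpy as np
from elastica._calculus import _clip_array

strain = np.random.randn(3, 10)
reference = np.clip(strain, -1.0, 1.0)
clipped = _clip_array(strain, -1.0, 1.0)  # the kernel may mutate its input
assert np.allclose(clipped, reference)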
@@ -304,7 +318,7 @@ def _clip_array(input_array, vmin, vmax):


@njit(cache=True)
-def _isnan_check(array):
+def _isnan_check(array: NDArray[Any]) -> bool:
    """
    This function checks if there is any nan inside the array.
    If there is a nan, it returns the boolean True.
    Python version: 2.24 µs ± 96.1 ns per loop
    This version: 479 ns ± 6.49 ns per loop
    """
-    return np.isnan(array).any()
+    return bool(np.isnan(array).any())
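The bool(...) wrapper is what makes the -> bool annotation honest: np.isnan(...).any() returns a numpy.bool_, which is not a subclass of Python's bool. A quick illustration:

import numpy as np

raw = np.isnan(np.array([1.0, np.nan])).any()
assert isinstance(raw, np.bool_) and not isinstance(raw, bool)
assert isinstance(bool(raw), bool)  # matches the declared return type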


position_difference_kernel = _difference
23 changes: 12 additions & 11 deletions elastica/_contact_functions.py
@@ -24,7 +26,9 @@
)
import numba
import numpy as np
+from numpy.typing import NDArray

+from elastica.typing import Float

@numba.njit(cache=True)
def _calculate_contact_forces_rod_cylinder(
@@ -784,17 +786,16 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction(

@numba.njit(cache=True)
def _calculate_contact_forces_cylinder_plane(
-    plane_origin,
-    plane_normal,
-    surface_tol,
-    k,
-    nu,
-    length,
-    position_collection,
-    velocity_collection,
-    external_forces,
-):
-
+    plane_origin: NDArray[np.floating],
+    plane_normal: NDArray[np.floating],
+    surface_tol: Float,
+    k: Float,
+    nu: Float,
+    length: NDArray[np.floating],
+    position_collection: NDArray[np.floating],
+    velocity_collection: NDArray[np.floating],
+    external_forces: NDArray[np.floating],
+) -> tuple[NDArray[np.floating], NDArray[np.intp]]:
    # Compute plane response force
    # total_forces = system.internal_forces + system.external_forces
    total_forces = external_forces
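A hedged calling sketch for the annotated signature. The shapes are my reading of it: a rigid cylinder contributes a single batched element, and I am assuming the returned tuple mirrors the rod-plane variants (response-force magnitudes plus indices of elements not in contact):

import numpy as np
from elastica._contact_functions import _calculate_contact_forces_cylinder_plane

force_mag, no_contact_idx = _calculate_contact_forces_cylinder_plane(
    np.zeros(3),                # plane_origin, shape (3,)
    np.array([0.0, 0.0, 1.0]),  # plane_normal, shape (3,)
    1e-4,                       # surface_tol
    1e4,                        # k, contact stiffness
    10.0,                       # nu, damping coefficient
    np.array([1.0]),            # length, shape (1,)
    np.zeros((3, 1)),           # position_collection, one batched element
    np.zeros((3, 1)),           # velocity_collection
    np.zeros((3, 1)),           # external_forces
)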
45 changes: 33 additions & 12 deletions elastica/_linalg.py
@@ -1,5 +1,6 @@
__doc__ = """ Convenient linear algebra kernels """
import numpy as np
+from numpy.typing import NDArray
from numba import njit
from numpy import sqrt
import functools
@@ -8,7 +9,7 @@


@functools.lru_cache(maxsize=1)
-def levi_civita_tensor(dim):
+def levi_civita_tensor(dim: int) -> NDArray[np.floating]:
    """
    Parameters
@@ -28,7 +29,9 @@ def levi_civita_tensor(dim):


@njit(cache=True)
-def _batch_matvec(matrix_collection, vector_collection):
+def _batch_matvec(
+    matrix_collection: NDArray[np.floating], vector_collection: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function does batch matrix and batch vector product
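In einsum notation the product is 'ijk,jk->ik' over a (3, 3, blocksize) matrix batch and a (3, blocksize) vector batch; a NumPy cross-check:

import numpy as np
from elastica._linalg import _batch_matvec

matrices = np.random.rand(3, 3, 10)
vectors = np.random.rand(3, 10)
assert np.allclose(
    _batch_matvec(matrices, vectors), np.einsum("ijk,jk->ik", matrices, vectors)
)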
@@ -59,7 +62,10 @@ def _batch_matvec(matrix_collection, vector_collection):


@njit(cache=True)
-def _batch_matmul(first_matrix_collection, second_matrix_collection):
+def _batch_matmul(
+    first_matrix_collection: NDArray[np.floating],
+    second_matrix_collection: NDArray[np.floating],
+) -> NDArray[np.floating]:
    """
    This is a batch matrix-matrix multiplication function. Only batches
    of 3x3 matrices can be multiplied.
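Likewise, the matrix product should follow 'ijk,jlk->ilk' on (3, 3, blocksize) batches:

import numpy as np
from elastica._linalg import _batch_matmul

a = np.random.rand(3, 3, 10)
b = np.random.rand(3, 3, 10)
assert np.allclose(_batch_matmul(a, b), np.einsum("ijk,jlk->ilk", a, b))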
@@ -93,7 +99,10 @@ def _batch_matmul(first_matrix_collection, second_matrix_collection):


@njit(cache=True)
-def _batch_cross(first_vector_collection, second_vector_collection):
+def _batch_cross(
+    first_vector_collection: NDArray[np.floating],
+    second_vector_collection: NDArray[np.floating],
+) -> NDArray[np.floating]:
    """
    This function does cross product between two batch vectors.
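The batch cross product acts column by column on (3, blocksize) inputs, so np.cross with axis=0 serves as an equivalent reference:

import numpy as np
from elastica._linalg import _batch_cross

u = np.random.rand(3, 10)
v = np.random.rand(3, 10)
assert np.allclose(_batch_cross(u, v), np.cross(u, v, axis=0))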
@@ -133,7 +142,9 @@ def _batch_cross(first_vector_collection, second_vector_collection):


@njit(cache=True)
-def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector):
+def _batch_vec_oneD_vec_cross(
+    first_vector_collection: NDArray[np.floating], second_vector: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function does cross product between a batch vector and a 1D vector.
    The idea of having this function is that, for friction calculations, we don't
@@ -177,7 +188,9 @@ def _batch_vec_oneD_vec_cross(first_vector_collection, second_vector):


@njit(cache=True)
-def _batch_dot(first_vector, second_vector):
+def _batch_dot(
+    first_vector: NDArray[np.floating], second_vector: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function does batch vec and batch vec dot product.
    Parameters
@@ -204,7 +217,7 @@ def _batch_dot(first_vector, second_vector):


@njit(cache=True)
-def _batch_norm(vector):
+def _batch_norm(vector: NDArray[np.floating]) -> NDArray[np.floating]:
    """
    This function computes norm of a batch vector
    Parameters
@@ -233,7 +246,9 @@ def _batch_norm(vector):


@njit(cache=True)
-def _batch_product_i_k_to_ik(vector1, vector2):
+def _batch_product_i_k_to_ik(
+    vector1: NDArray[np.floating], vector2: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function does outer product following 'i,k->ik'.
    vector1 has shape of 3 and vector2 has shape of blocksize
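Since 'i,k->ik' is just an outer product, np.outer works as the reference implementation:

import numpy as np
from elastica._linalg import _batch_product_i_k_to_ik

direction = np.random.rand(3)
weights = np.random.rand(10)
assert np.allclose(
    _batch_product_i_k_to_ik(direction, weights), np.outer(direction, weights)
)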
@@ -262,7 +277,9 @@ def _batch_product_i_k_to_ik(vector1, vector2):


@njit(cache=True)
-def _batch_product_i_ik_to_k(vector1, vector2):
+def _batch_product_i_ik_to_k(
+    vector1: NDArray[np.floating], vector2: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function does the following product 'i,ik->k'
    This function does a dot product between a vector of 3 elements
@@ -293,7 +310,9 @@ def _batch_product_i_ik_to_k(vector1, vector2):


@njit(cache=True)
-def _batch_product_k_ik_to_ik(vector1, vector2):
+def _batch_product_k_ik_to_ik(
+    vector1: NDArray[np.floating], vector2: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function does the following product 'k, ik->ik'
    Parameters
@@ -322,7 +341,9 @@ def _batch_product_k_ik_to_ik(vector1, vector2):


@njit(cache=True)
-def _batch_vector_sum(vector1, vector2):
+def _batch_vector_sum(
+    vector1: NDArray[np.floating], vector2: NDArray[np.floating]
+) -> NDArray[np.floating]:
    """
    This function is for summing up two vectors. Although
    this function is not faster than a pure Python implementation
@@ -352,7 +373,7 @@ def _batch_vector_sum(vector1, vector2):


@njit(cache=True)
-def _batch_matrix_transpose(input_matrix):
+def _batch_matrix_transpose(input_matrix: NDArray[np.floating]) -> NDArray[np.floating]:
    """
    This function takes a batch input matrix and transposes it.
    Parameters
