Skip to content

Commit

Permalink
Merge pull request #367 from GazzolaLab/update/mypy
Browse files Browse the repository at this point in the history
Type-hinting elastica
  • Loading branch information
skim0119 authored Jun 28, 2024
2 parents a6f92ce + 92d8b91 commit a7ce95c
Show file tree
Hide file tree
Showing 101 changed files with 4,362 additions and 5,055 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ jobs:
- name: Run tests
run: |
make test
- name: Typechecking
if: ${{ startsWith(runner.os, 'macOS') }}
run: |
make mypy
report-coverage: # Report coverage from python 3.10 and mac-os. May change later
runs-on: ${{ matrix.os }}
strategy:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ __pycache__/
# C extensions
*.so

*.swp

# Distribution / packaging
.Python
build/
Expand Down
36 changes: 22 additions & 14 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#* Variables
PYTHON := python3
PYTHONPATH := `pwd`
AUTOFLAKE8_ARGS := -r --exclude '__init__.py' --keep-pass-after-docstring
AUTOFLAKE_ARGS := -r
#* Poetry
.PHONY: poetry-download
poetry-download:
Expand Down Expand Up @@ -47,19 +47,23 @@ flake8:
poetry run flake8 --version
poetry run flake8 elastica tests

.PHONY: autoflake8-check
autoflake8-check:
poetry run autoflake8 --version
poetry run autoflake8 $(AUTOFLAKE8_ARGS) elastica tests examples
poetry run autoflake8 --check $(AUTOFLAKE8_ARGS) elastica tests examples
.PHONY: autoflake-check
autoflake-check:
poetry run autoflake --version
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: autoflake8-format
autoflake8-format:
poetry run autoflake8 --version
poetry run autoflake8 --in-place $(AUTOFLAKE8_ARGS) elastica tests examples
.PHONY: autoflake-format
autoflake-format:
poetry run autoflake --version
poetry run autoflake --in-place $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: format-codestyle
format-codestyle: black flake8
format-codestyle: black autoflake-format

.PHONY: mypy
mypy:
poetry run mypy --config-file pyproject.toml elastica

.PHONY: test
test:
Expand All @@ -74,14 +78,14 @@ test_coverage_xml:
NUMBA_DISABLE_JIT=1 poetry run pytest --cov=elastica --cov-report=xml

.PHONY: check-codestyle
check-codestyle: black-check flake8 autoflake8-check
check-codestyle: black-check flake8 autoflake-check

.PHONY: formatting
formatting: format-codestyle

.PHONY: update-dev-deps
update-dev-deps:
poetry add -D pytest@latest coverage@latest pytest-html@latest pytest-cov@latest black@latest
poetry add -D mypy@latest pytest@latest coverage@latest pytest-html@latest pytest-cov@latest black@latest

#* Cleaning
.PHONY: pycache-remove
Expand All @@ -92,6 +96,10 @@ pycache-remove:
dsstore-remove:
find . | grep -E ".DS_Store" | xargs rm -rf

.PHONY: mypycache-remove
mypycache-remove:
find . | grep -E ".mypy_cache" | xargs rm -rf

.PHONY: ipynbcheckpoints-remove
ipynbcheckpoints-remove:
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
Expand All @@ -105,7 +113,7 @@ build-remove:
rm -rf build/

.PHONY: cleanup
cleanup: pycache-remove dsstore-remove ipynbcheckpoints-remove pytestcache-remove
cleanup: pycache-remove dsstore-remove ipynbcheckpoints-remove pytestcache-remove mypycache-remove

all: format-codestyle cleanup test

Expand Down
110 changes: 108 additions & 2 deletions docs/advanced/PackageDesign.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,114 @@
# Code Design: Mixin and Composition
# Code Design

## Mixin and Composition

Elastica package follows Mixin and composition design patterns that may be unfamiliar to users. Here is a collection of references that introduce the package design.

## References
### References

- [stackoverflow discussion on Mixin](https://stackoverflow.com/questions/533631/what-is-a-mixin-and-why-are-they-useful)
- [example of Mixin: python collections](https://docs.python.org/dev/library/collections.abc.html)

## Duck Typing

Elastica package uses duck typing to allow users to define their own classes and functions. Here is a `typing.Protocol` structure that is used in the package.

### Systems

``` {mermaid}
flowchart LR
direction RL
subgraph Systems Protocol
direction RL
SLBD(SlenderBodyGeometryProtool)
SymST["SymplecticSystem:\n• KinematicStates/Rates\n• DynamicStates/Rates"]
style SymST text-align:left
ExpST["ExplicitSystem:\n• States (Unused)"]
style ExpST text-align:left
P((position\nvelocity\nacceleration\n..)) --> SLBD
subgraph StaticSystemType
Surface
Mesh
end
subgraph SystemType
direction TB
Rod
RigidBody
end
SLBD --> SymST
SystemType --> SymST
SLBD --> ExpST
SystemType --> ExpST
end
subgraph Timestepper Protocol
direction TB
StP["StepperProtocol\n• step(SystemCollection, time, dt)"]
style StP text-align:left
SymplecticStepperProtocol["SymplecticStepperProtocol\n• PositionVerlet"]
style SymplecticStepperProtocol text-align:left
ExpplicitStepperProtocol["ExpplicitStepperProtocol\n(Unused)"]
end
subgraph SystemCollection
end
SymST --> SystemCollection --> SymplecticStepperProtocol
ExpST --> SystemCollection --> ExpplicitStepperProtocol
StaticSystemType --> SystemCollection
```

### System Collection (Build memory block)

``` {mermaid}
flowchart LR
Sys((Systems))
St((Stepper))
subgraph SystemCollectionType
direction LR
StSys["StaticSystem:\n• Surface\n• Mesh"]
style StSys text-align:left
DynSys["DynamicSystem:\n• Rod\n  • CosseratRod\n• RigidBody\n  • Sphere\n  • Cylinder"]
style DynSys text-align:left
BlDynSys["BlockSystemType:\n• BlockCosseratRod\n• BlockRigidBody"]
style BlDynSys text-align:left
F{{"Feature Group (OperatorGroup):\n• Synchronize\n• Constrain values\n• Constrain rates\n• Callback"}}
style F text-align:left
end
Sys --> StSys --> F
Sys --> DynSys -->|Finalize| BlDynSys --> St
DynSys --> F <--> St
```

### System Collection (Features)

``` {mermaid}
flowchart LR
Sys((Systems))
St((Stepper))
subgraph SystemCollectionType
direction LR
StSys["StaticSystem:\n• Surface\n• Mesh"]
style StSys text-align:left
DynSys["DynamicSystem:\n• Rod\n&nbsp;&nbsp;• CosseratRod\n• RigidBody\n&nbsp;&nbsp;• Sphere\n&nbsp;&nbsp;• Cylinder"]
style DynSys text-align:left
subgraph Feature
direction LR
Forcing -->|add_forcing_to| Synchronize
Constraints -->|constrain| ConstrainValues
Constraints -->|constrain| ConstrainRates
Contact -->|detect_contact_between| Synchronize
Connection -->|connect| Synchronize
Damping -->|dampen| ConstrainRates
Callback -->|collect_diagnosis| CallbackGroup
end
end
Sys --> StSys --> Feature
Sys --> DynSys
DynSys --> Feature <--> St
```
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.mathjax',
"sphinxcontrib.mermaid",
'numpydoc',
'myst_parser',
]
Expand Down Expand Up @@ -98,3 +99,6 @@

# -- Options for numpydoc ---------------------------------------------------
numpydoc_show_class_members = False

# -- Mermaid configuration ---------------------------------------------------
mermaid_params = ['--theme', 'neutral']
7 changes: 0 additions & 7 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from collections import defaultdict
from elastica.rod.knot_theory import (
KnotTheory,
KnotTheoryCompatibleProtocol,
compute_link,
compute_twist,
compute_writhe,
Expand All @@ -19,8 +17,6 @@
GeneralConstraint,
FixedConstraint,
HelicalBucklingBC,
FreeRod,
OneEndFixedRod,
)
from elastica.external_forces import (
NoForces,
Expand All @@ -38,10 +34,8 @@
)
from elastica.joint import (
FreeJoint,
ExternalContact,
FixedJoint,
HingeJoint,
SelfContact,
)
from elastica.contact_forces import (
NoContact,
Expand Down Expand Up @@ -79,7 +73,6 @@
)
from elastica._linalg import levi_civita_tensor
from elastica.utils import isqrt
from elastica.typing import RodType, SystemType, AllowedContactType
from elastica.timestepper import (
integrate,
PositionVerlet,
Expand Down
53 changes: 27 additions & 26 deletions elastica/_calculus.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
__doc__ = """ Quadrature and difference kernels """
import numpy as np
from numpy import zeros, empty
from numpy.typing import NDArray
import numba
from numba import njit
from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import (
_reset_vector_ghost,
)
import functools


@functools.lru_cache(maxsize=2)
def _get_zero_array(dim, ndim):
if ndim == 1:
return 0.0
if ndim == 2:
return np.zeros((dim, 1))


@njit(cache=True)
def _trapezoidal(array_collection):
@njit(cache=True) # type: ignore
def _trapezoidal(array_collection: NDArray[np.float64]) -> NDArray[np.float64]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way
Expand Down Expand Up @@ -62,8 +55,10 @@ def _trapezoidal(array_collection):
return temp_collection


@njit(cache=True)
def _trapezoidal_for_block_structure(array_collection, ghost_idx):
@njit(cache=True) # type: ignore
def _trapezoidal_for_block_structure(
array_collection: NDArray[np.float64], ghost_idx: NDArray[np.int32]
) -> NDArray[np.float64]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form
specifically for the block structure implementation and there is a reset function call, to reset
Expand Down Expand Up @@ -114,8 +109,10 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx):
return temp_collection


@njit(cache=True)
def _two_point_difference(array_collection):
@njit(cache=True) # type: ignore
def _two_point_difference(
array_collection: NDArray[np.float64],
) -> NDArray[np.float64]:
"""
This function does differentiation.
Expand Down Expand Up @@ -155,8 +152,10 @@ def _two_point_difference(array_collection):
return temp_collection


@njit(cache=True)
def _two_point_difference_for_block_structure(array_collection, ghost_idx):
@njit(cache=True) # type: ignore
def _two_point_difference_for_block_structure(
array_collection: NDArray[np.float64], ghost_idx: NDArray[np.int32]
) -> NDArray[np.float64]:
"""
This function does the differentiation, for Cosserat rod model equations. This form
specifically for the block structure implementation and there is a reset function call, to
Expand Down Expand Up @@ -206,8 +205,8 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx):
return temp_collection


@njit(cache=True)
def _difference(vector):
@njit(cache=True) # type: ignore
def _difference(vector: NDArray[np.float64]) -> NDArray[np.float64]:
"""
This function computes difference between elements of a batch vector.
Expand Down Expand Up @@ -237,8 +236,8 @@ def _difference(vector):
return output_vector


@njit(cache=True)
def _average(vector):
@njit(cache=True) # type: ignore
def _average(vector: NDArray[np.float64]) -> NDArray[np.float64]:
"""
This function computes the average between elements of a vector.
Expand Down Expand Up @@ -267,8 +266,10 @@ def _average(vector):
return output_vector


@njit(cache=True)
def _clip_array(input_array, vmin, vmax):
@njit(cache=True) # type: ignore
def _clip_array(
input_array: NDArray[np.float64], vmin: np.float64, vmax: np.float64
) -> NDArray[np.float64]:
"""
This function clips an array values
between user defined minimum and maximum
Expand Down Expand Up @@ -303,8 +304,8 @@ def _clip_array(input_array, vmin, vmax):
return input_array


@njit(cache=True)
def _isnan_check(array):
@njit(cache=True) # type: ignore
def _isnan_check(array: NDArray) -> bool:
"""
This function checks if there is any nan inside the array.
If there is nan, it returns True boolean.
Expand All @@ -324,7 +325,7 @@ def _isnan_check(array):
Python version: 2.24 µs ± 96.1 ns per loop
This version: 479 ns ± 6.49 ns per loop
"""
return np.isnan(array).any()
return bool(np.isnan(array).any())


position_difference_kernel = _difference
Expand Down
Loading

0 comments on commit a7ce95c

Please sign in to comment.