Skip to content

Commit

Permalink
update: operation order determined by user's definition
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed May 9, 2024
1 parent 4d42fbe commit f6d4b97
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 100 deletions.
44 changes: 17 additions & 27 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
"""
from typing import Iterable, Callable, AnyStr
from elastica.typing import OperatorType

from collections.abc import MutableSequence

from elastica.rod import RodBase
from elastica.rigidbody import RigidBodyBase
from elastica.surface import SurfaceBase
from elastica.modules.memory_block import construct_memory_block_structures
from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries

from .memory_block import construct_memory_block_structures
from .feature_group import FeatureGroupFIFO


class BaseSystemCollection(MutableSequence):
"""
Expand Down Expand Up @@ -45,13 +48,11 @@ def __init__(self):
# Collection of functions. Each group is executed as a collection at the different steps.
# Each component (Forcing, Connection, etc.) registers the executable (callable) function
# in the group that that needs to be executed. These should be initialized before mixin.
self._feature_group_synchronize: Iterable[Callable[[float], None]] = []
self._feature_group_constrain_values: Iterable[Callable[[float], None]] = []
self._feature_group_constrain_rates: Iterable[Callable[[float], None]] = []
self._feature_group_callback: Iterable[Callable[[float, int, AnyStr], None]] = (
[]
)
self._feature_group_finalize: Iterable[Callable] = []
self._feature_group_synchronize: Iterable[OperatorType] = FeatureGroupFIFO()
self._feature_group_constrain_values: Iterable[OperatorType] = []
self._feature_group_constrain_rates: Iterable[OperatorType] = []
self._feature_group_callback: Iterable[OperatorCallbackType] = []
self._feature_group_finalize: Iterable[OpeatorFinalizeType] = []
# We need to initialize our mixin classes
super(BaseSystemCollection, self).__init__()
# List of system types/bases that are allowed
Expand Down Expand Up @@ -169,34 +170,23 @@ def finalize(self):

# Toggle the finalize_flag
self._finalize_flag = True
# sort _feature_group_synchronize so that _call_contacts is at the end
_call_contacts_index = []
for idx, feature in enumerate(self._feature_group_synchronize):
if feature.__name__ == "_call_contacts":
_call_contacts_index.append(idx)

# Move to the _call_contacts to the end of the _feature_group_synchronize list.
for index in _call_contacts_index:
self._feature_group_synchronize.append(
self._feature_group_synchronize.pop(index)
)

def synchronize(self, time: float):
# Collection call _feature_group_synchronize
for feature in self._feature_group_synchronize:
feature(time)
for func in self._feature_group_synchronize:
func(time)

def constrain_values(self, time: float):
# Collection call _feature_group_constrain_values
for feature in self._feature_group_constrain_values:
feature(time)
for func in self._feature_group_constrain_values:
func(time)

def constrain_rates(self, time: float):
# Collection call _feature_group_constrain_rates
for feature in self._feature_group_constrain_rates:
feature(time)
for func in self._feature_group_constrain_rates:
func(time)

def apply_callbacks(self, time: float, current_step: int):
# Collection call _feature_group_callback
for feature in self._feature_group_callback:
feature(time, current_step)
for func in self._feature_group_callback:
func(time, current_step)
58 changes: 29 additions & 29 deletions elastica/modules/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Provides the connections interface to connect entities (rods,
rigid bodies) using joints (see `joints.py`).
"""
import functools
import numpy as np
from elastica.joint import FreeJoint

Expand All @@ -24,7 +25,6 @@ class Connections:
def __init__(self):
self._connections = []
super(Connections, self).__init__()
self._feature_group_synchronize.append(self._call_connections)
self._feature_group_finalize.append(self._finalize_connections)

def connect(
Expand Down Expand Up @@ -63,6 +63,7 @@ def connect(
_connector = _Connect(*sys_idx, *sys_dofs)
_connector.set_index(first_connect_idx, second_connect_idx)
self._connections.append(_connector)
self._feature_group_synchronize.append_id(_connector)

return _connector

Expand All @@ -71,38 +72,37 @@ def _finalize_connections(self):

# dev : the first indices stores the
# (first rod index, second_rod_idx, connection_idx_on_first_rod, connection_idx_on_second_rod)
# to apply the connections to
# Technically we can use another array but it its one more book-keeping
# step. Being lazy, I put them both in the same array
self._connections[:] = [
(*connection.id(), connection()) for connection in self._connections
]
# to apply the connections to.

for connection in self._connections:
first_sys_idx, second_sys_idx, first_connect_idx, second_connect_idx = (
connection.id()
)
connect_instance = connection.instantiate()

# FIXME: lambda t is included because OperatorType takes time as an argument
apply_forces = lambda t: functools.partial(
connect_instance.apply_forces,
system_one=self._systems[first_sys_idx],
index_one=first_connect_idx,
system_two=self._systems[second_sys_idx],
index_two=second_connect_idx,
)
apply_torques = lambda t: functools.partial(
connect_instance.apply_torques,
system_one=self._systems[first_sys_idx],
index_one=first_connect_idx,
system_two=self._systems[second_sys_idx],
index_two=second_connect_idx,
)
self._feature_group_synchronize.add_operators(
connection, [apply_forces, apply_torques]
)

# Need to finally solve CPP here, if we are doing things properly
# This is to optimize the call tree for better memory accesses
# https://brooksandrew.github.io/simpleblog/articles/intro-to-graph-optimization-solving-cpp/

def _call_connections(self, *args, **kwargs):
for (
first_sys_idx,
second_sys_idx,
first_connect_idx,
second_connect_idx,
connection,
) in self._connections:
connection.apply_forces(
self._systems[first_sys_idx],
first_connect_idx,
self._systems[second_sys_idx],
second_connect_idx,
)
connection.apply_torques(
self._systems[first_sys_idx],
first_connect_idx,
self._systems[second_sys_idx],
second_connect_idx,
)


class _Connect:
"""
Expand Down Expand Up @@ -265,7 +265,7 @@ def id(self):
self.second_sys_connection_idx,
)

def __call__(self, *args, **kwargs):
def instantiate(self):
if not self._connect_cls:
raise RuntimeError(
"No connections provided to link rod id {0}"
Expand Down
34 changes: 13 additions & 21 deletions elastica/modules/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Provides the contact interface to apply contact forces between objects
(rods, rigid bodies, surfaces).
"""

import functools
from elastica.typing import SystemType, AllowedContactType


Expand All @@ -23,7 +23,6 @@ class Contact:
def __init__(self):
self._contacts = []
super(Contact, self).__init__()
self._feature_group_synchronize.append(self._call_contacts)
self._feature_group_finalize.append(self._finalize_contact)

def detect_contact_between(
Expand Down Expand Up @@ -51,6 +50,7 @@ def detect_contact_between(
# Create _Contact object, cache it and return to user
_contact = _Contact(*sys_idx)
self._contacts.append(_contact)
self._feature_group_synchronize.append_id(_contact)

return _contact

Expand All @@ -61,29 +61,21 @@ def _finalize_contact(self) -> None:
# to apply the contacts to
# Technically we can use another array but it its one more book-keeping
# step. Being lazy, I put them both in the same array
self._contacts[:] = [(*contact.id(), contact()) for contact in self._contacts]

# check contact order
for (
first_sys_idx,
second_sys_idx,
contact,
) in self._contacts:
contact._check_systems_validity(
for contact in self._contacts:
first_sys_idx, second_sys_idx = contact.id()
contact_instance = contact.instantiate()

contact_instance._check_systems_validity(
self._systems[first_sys_idx],
self._systems[second_sys_idx],
)

def _call_contacts(self, time: float):
for (
first_sys_idx,
second_sys_idx,
contact,
) in self._contacts:
contact.apply_contact(
self._systems[first_sys_idx],
self._systems[second_sys_idx],
apply_contact = lambda t: functools.partial(
contact_instance.apply_contact,
system_one=self._systems[first_sys_idx],
system_two=self._systems[second_sys_idx],
)
self._feature_group_synchronize.add_operators(contact, [apply_contact])


class _Contact:
Expand Down Expand Up @@ -153,7 +145,7 @@ def id(self):
self.second_sys_idx,
)

def __call__(self, *args, **kwargs):
def instantiate(self, *args, **kwargs):
if not self._contact_cls:
raise RuntimeError(
"No contacts provided to to establish contact between rod-like object id {0}"
Expand Down
27 changes: 27 additions & 0 deletions elastica/modules/feature_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Callable
from elastica.typing import OperatorType

from collection.abc import Iterable

import itertools


class FeatureGroupFIFO(Iterable):
def __init__(self):
self._operator_collection: list[list[OperatorType]] = []
self._operator_ids: list[int] = []

def __iter__(self) -> Callable[[...], None]:
if not self._operator_collection:
raise RuntimeError("Feature group is not instantiated.")
operator_chain = itertools.chain(self._operator_collection)
for operator in operator_chain:
yield operator

def append_id(self, feature):
self._operator_ids.append(id(feature))
self._operator_collection.append([])

def add_operators(self, feature, operators: list[OperatorType]):
idx = self._operator_idx.index(feature)
self._operator_collection[idx].extend(operators)
41 changes: 19 additions & 22 deletions elastica/modules/forcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Provides the forcing interface to apply forces and torques to rod-like objects
(external point force, muscle torques, etc).
"""
import functools
from elastica.interaction import AnisotropicFrictionalPlane


Expand All @@ -23,7 +24,6 @@ class Forcing:
def __init__(self):
self._ext_forces_torques = []
super(Forcing, self).__init__()
self._feature_group_synchronize.append(self._call_ext_forces_torques)
self._feature_group_finalize.append(self._finalize_forcing)

def add_forcing_to(self, system):
Expand All @@ -46,6 +46,7 @@ def add_forcing_to(self, system):
# Create _Constraint object, cache it and return to user
_ext_force_torque = _ExtForceTorque(sys_idx)
self._ext_forces_torques.append(_ext_force_torque)
self._feature_group_synchronize.append_id(_ext_force_torque)

return _ext_force_torque

Expand All @@ -54,21 +55,23 @@ def _finalize_forcing(self):
# inplace : https://stackoverflow.com/a/1208792

# dev : the first index stores the rod index to apply the boundary condition
# to. Technically we can use another array but it its one more book-keeping
# step. Being lazy, I put them both in the same array
self._ext_forces_torques[:] = [
(ext_force_torque.id(), ext_force_torque())
for ext_force_torque in self._ext_forces_torques
]

# Sort from lowest id to highest id for potentially better memory access
# _ext_forces_torques contains list of tuples. First element of tuple is
# rod number and following elements are the type of boundary condition such as
# [(0, NoForces, GravityForces), (1, UniformTorques), ... ]
# Thus using lambda we iterate over the list of tuples and use rod number (x[0])
# to sort _ext_forces_torques.
self._ext_forces_torques.sort(key=lambda x: x[0])
# to.
for ext_force_torque in self._ext_forces_torques:
sys_id = ext_force_torque.id()
forcing_instance = ext_force_torque.instantiate()

apply_forces = functools.partial(
forcing_instance.apply_forces, system=self._systems[sys_id]
)
apply_torques = functools.partial(
forcing_instance.apply_torques, system=self._systems[sys_id]
)

self._feature_group_synchronize.add_operators(
ext_force_torque, [apply_forces, apply_torques]
)

# TODO: remove: we decided to let user to fully decide the order of operations
# Find if there are any friction plane forcing, if add them to the end of the list,
# since friction planes uses external forces.
friction_plane_index = []
Expand All @@ -80,12 +83,6 @@ def _finalize_forcing(self):
for index in friction_plane_index:
self._ext_forces_torques.append(self._ext_forces_torques.pop(index))

def _call_ext_forces_torques(self, time, *args, **kwargs):
for sys_id, ext_force_torque in self._ext_forces_torques:
ext_force_torque.apply_forces(self._systems[sys_id], time, *args, **kwargs)
ext_force_torque.apply_torques(self._systems[sys_id], time, *args, **kwargs)
# TODO Apply torque, see if necessary


class _ExtForceTorque:
"""
Expand Down Expand Up @@ -146,7 +143,7 @@ def using(self, forcing_cls, *args, **kwargs):
def id(self):
return self._sys_idx

def __call__(self, *args, **kwargs):
def instantiate(self):
"""Constructs a constraint after checks
Parameters
Expand Down
6 changes: 5 additions & 1 deletion elastica/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
from elastica.rigidbody import RigidBodyBase
from elastica.surface import SurfaceBase

from typing import Type, Union
from typing import Type, Union, TypingAlias, Callable

RodType = Type[RodBase]
SystemType = Union[RodType, Type[RigidBodyBase]]
AllowedContactType = Union[SystemType, Type[SurfaceBase]]

OperatorType: TypingAlias = Callable[[float], None]
OperatorCallbackType: TypingAlias = Callable[[float, int], None]
OperatorFinalizeType: TypingAlias = Callable

0 comments on commit f6d4b97

Please sign in to comment.