Skip to content

Commit

Permalink
refactor: clarify iterating systems vs block_systems
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 26, 2024
1 parent 9fdc9f0 commit 499ae6d
Show file tree
Hide file tree
Showing 19 changed files with 117 additions and 107 deletions.
71 changes: 38 additions & 33 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from elastica.typing import (
StaticSystemType,
SystemType,
StaticSystemType,
BlockSystemType,
SystemIdxType,
OperatorType,
OperatorCallbackType,
Expand Down Expand Up @@ -38,8 +40,10 @@ class BaseSystemCollection(MutableSequence):
----------
allowed_sys_types: tuple
Tuple of allowed type rod-like objects. Here use a base class for objects, i.e. RodBase.
_systems: list
List of rod-like objects.
systems: Callabke
Returns all system objects. Once finalize, block objects are also included.
blocks: Callable
Returns block objects. Should be called after finalize.
Note
----
Expand Down Expand Up @@ -68,8 +72,8 @@ def __init__(self) -> None:
)

# List of systems to be integrated
self._systems: list[StaticSystemType] = []
self.__final_systems: list[SystemType] = []
self.__systems: list[StaticSystemType] = []
self.__final_blocks: list[BlockSystemType] = []

# Flag Finalize: Finalizing twice will cause an error,
# but the error message is very misleading
Expand Down Expand Up @@ -99,7 +103,7 @@ def _check_type(self, sys_to_be_added: Any) -> bool:
return True

def __len__(self) -> int:
return len(self._systems)
return len(self.__systems)

@overload
def __getitem__(self, idx: int, /) -> SystemType: ...
Expand All @@ -108,22 +112,22 @@ def __getitem__(self, idx: int, /) -> SystemType: ...
def __getitem__(self, idx: slice, /) -> list[SystemType]: ...

def __getitem__(self, idx, /): # type: ignore
return self._systems[idx]
return self.__systems[idx]

def __delitem__(self, idx, /): # type: ignore
del self._systems[idx]
del self.__systems[idx]

def __setitem__(self, idx, system, /): # type: ignore
self._check_type(system)
self._systems[idx] = system
self.__systems[idx] = system

def insert(self, idx, system) -> None: # type: ignore
self._check_type(system)
self._systems.insert(idx, system)
self.__systems.insert(idx, system)

def __str__(self) -> str:
"""To be readable"""
return str(self._systems)
return str(self.__systems)

@final
def extend_allowed_types(
Expand All @@ -138,38 +142,43 @@ def override_allowed_types(
self.allowed_sys_types = allowed_types

@final
def _get_sys_idx_if_valid(
self, sys_to_be_added: "SystemType | StaticSystemType"
def get_system_index(
self, system: "SystemType | StaticSystemType"
) -> SystemIdxType:
n_systems = len(self) # Total number of systems from mixed-in class

sys_idx: SystemIdxType
if isinstance(sys_to_be_added, (int, np.int_)):
if isinstance(system, (int, np.int_)):
# 1. If they are indices themselves, check range
# This is only used for testing purposes
assert (
-n_systems <= sys_to_be_added < n_systems
), "Rod index {} exceeds number of registered rodtems".format(
sys_to_be_added
)
sys_idx = int(sys_to_be_added)
elif self._check_type(sys_to_be_added):
# 2. If they are rod objects (most likely), lookup indices
-n_systems <= system < n_systems
), "System index {} exceeds number of registered rodtems".format(system)
sys_idx = int(system)
elif self._check_type(system):
# 2. If they are system object (most likely), lookup indices
# index might have some problems : https://stackoverflow.com/a/176921
try:
sys_idx = self._systems.index(sys_to_be_added)
sys_idx = self.__systems.index(system)
except ValueError:
raise ValueError(
"Rod {} was not found, did you append it to the system?".format(
sys_to_be_added
"System {} was not found, did you append it to the system?".format(
system
)
)

return sys_idx

@final
def systems(self) -> Generator[SystemType, None, None]:
def systems(self) -> Generator[StaticSystemType, None, None]:
# assert self._finalize_flag, "The simulator is not finalized."
for system in self.__systems:
yield system

@final
def block_systems(self) -> Generator[BlockSystemType, None, None]:
# assert self._finalize_flag, "The simulator is not finalized."
for block in self.__final_systems:
for block in self.__final_blocks:
yield block

@final
Expand All @@ -184,15 +193,11 @@ def finalize(self) -> None:
assert not self._finalize_flag, "The finalize cannot be called twice."
self._finalize_flag = True

# construct memory block
self.__final_systems = construct_memory_block_structures(self._systems)
self._systems.extend(
self.__final_systems
) # FIXME: We need this to make ring-rod working.
# Construct memory block
self.__final_blocks = construct_memory_block_structures(self.__systems)
# FIXME: We need this to make ring-rod working.
# But probably need to be refactored
# TODO: try to remove the _systems list for memory optimization
# self._systems.clear()
# del self._systems
self.__systems.extend(self.__final_blocks)

# Recurrent call finalize functions for all components.
for finalize in self._feature_group_finalize:
Expand Down
4 changes: 2 additions & 2 deletions elastica/modules/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def collect_diagnostics(
-------
"""
sys_idx: SystemIdxType = self._get_sys_idx_if_valid(system)
sys_idx: SystemIdxType = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_callbacks: ModuleProtocol = _CallBack(sys_idx)
Expand All @@ -77,7 +77,7 @@ def _callback_execution(
current_step: int,
) -> None:
for sys_id, callback in self._callback_operators:
callback.make_callback(self._systems[sys_id], time, current_step)
callback.make_callback(self[sys_id], time, current_step)


class _CallBack:
Expand Down
8 changes: 4 additions & 4 deletions elastica/modules/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def connect(
"""
# For each system identified, get max dofs
sys_idx_first = self._get_sys_idx_if_valid(first_rod)
sys_idx_second = self._get_sys_idx_if_valid(second_rod)
sys_idx_first = self.get_system_index(first_rod)
sys_idx_second = self.get_system_index(second_rod)
sys_dofs_first = first_rod.n_elems
sys_dofs_second = second_rod.n_elems

Expand Down Expand Up @@ -118,9 +118,9 @@ def apply_forces_and_torques(
func = functools.partial(
apply_forces_and_torques,
connect_instance=connect_instance,
system_one=self._systems[first_sys_idx],
system_one=self[first_sys_idx],
first_connect_idx=first_connect_idx,
system_two=self._systems[second_sys_idx],
system_two=self[second_sys_idx],
second_connect_idx=second_connect_idx,
)

Expand Down
14 changes: 7 additions & 7 deletions elastica/modules/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def constrain(
-------
"""
sys_idx = self._get_sys_idx_if_valid(system)
sys_idx = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_constraint: ModuleProtocol = _Constraint(sys_idx)
Expand All @@ -73,13 +73,13 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:
"""
from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries

for block in self.systems():
for block in self.block_systems():
# append the memory block to the simulation as a system. Memory block is the final system in the simulation.
if hasattr(block, "ring_rod_flag"):
# Apply the constrain to synchronize the periodic boundaries of the memory rod. Find the memory block
# sys idx among other systems added and then apply boundary conditions.
memory_block_idx = self._get_sys_idx_if_valid(block)
block_system = cast(BlockSystemType, self._systems[memory_block_idx])
memory_block_idx = self.get_system_index(block)
block_system = cast(BlockSystemType, self[memory_block_idx])
self.constrain(block_system).using(
_ConstrainPeriodicBoundaries,
)
Expand All @@ -90,7 +90,7 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:
# dev : the first index stores the rod index to apply the boundary condition
# to.
self._constraints_operators = [
(constraint.id(), constraint.instantiate(self._systems[constraint.id()]))
(constraint.id(), constraint.instantiate(self[constraint.id()]))
for constraint in self._constraints_list
]

Expand All @@ -109,11 +109,11 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:

def _constrain_values(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, constraint in self._constraints_operators:
constraint.constrain_values(self._systems[sys_id], time)
constraint.constrain_values(self[sys_id], time)

def _constrain_rates(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, constraint in self._constraints_operators:
constraint.constrain_rates(self._systems[sys_id], time)
constraint.constrain_rates(self[sys_id], time)


class _Constraint:
Expand Down
12 changes: 6 additions & 6 deletions elastica/modules/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def detect_contact_between(
-------
"""
sys_idx_first = self._get_sys_idx_if_valid(first_system)
sys_idx_second = self._get_sys_idx_if_valid(second_system)
sys_idx_first = self.get_system_index(first_system)
sys_idx_second = self.get_system_index(second_system)

# Create _Contact object, cache it and return to user
_contact = _Contact(sys_idx_first, sys_idx_second)
Expand All @@ -86,17 +86,17 @@ def apply_contact(
second_sys_idx: SystemIdxType,
) -> None:
contact_instance.apply_contact(
system_one=self._systems[first_sys_idx],
system_two=self._systems[second_sys_idx],
system_one=self[first_sys_idx],
system_two=self[second_sys_idx],
)

for contact in self._contacts:
first_sys_idx, second_sys_idx = contact.id()
contact_instance = contact.instantiate()

contact_instance._check_systems_validity(
self._systems[first_sys_idx],
self._systems[second_sys_idx],
self[first_sys_idx],
self[second_sys_idx],
)
func = functools.partial(
apply_contact,
Expand Down
6 changes: 3 additions & 3 deletions elastica/modules/damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def dampen(self: SystemCollectionProtocol, system: RodType) -> ModuleProtocol:
-------
"""
sys_idx = self._get_sys_idx_if_valid(system)
sys_idx = self.get_system_index(system)

# Create _Damper object, cache it and return to user
_damper: ModuleProtocol = _Damper(sys_idx)
Expand All @@ -65,7 +65,7 @@ def _finalize_dampers(self: SystemCollectionProtocol) -> None:
# inplace : https://stackoverflow.com/a/1208792

self._damping_operators = [
(damper.id(), damper.instantiate(self._systems[damper.id()]))
(damper.id(), damper.instantiate(self[damper.id()]))
for damper in self._damping_list
]

Expand All @@ -78,7 +78,7 @@ def _finalize_dampers(self: SystemCollectionProtocol) -> None:

def _dampen_rates(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, damper in self._damping_operators:
damper.dampen_rates(self._systems[sys_id], time)
damper.dampen_rates(self[sys_id], time)


class _Damper:
Expand Down
6 changes: 3 additions & 3 deletions elastica/modules/forcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def add_forcing_to(
-------
"""
sys_idx = self._get_sys_idx_if_valid(system)
sys_idx = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_ext_force_torque = _ExtForceTorque(sys_idx)
Expand All @@ -73,10 +73,10 @@ def _finalize_forcing(self: SystemCollectionProtocol) -> None:
forcing_instance = external_force_and_torque.instantiate()

apply_forces = functools.partial(
forcing_instance.apply_forces, system=self._systems[sys_id]
forcing_instance.apply_forces, system=self[sys_id]
)
apply_torques = functools.partial(
forcing_instance.apply_torques, system=self._systems[sys_id]
forcing_instance.apply_torques, system=self[sys_id]
)

self._feature_group_synchronize.add_operators(
Expand Down
3 changes: 1 addition & 2 deletions elastica/modules/memory_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
RigidBodyType,
SurfaceType,
StaticSystemType,
SystemType,
SystemIdxType,
BlockSystemType,
)
Expand All @@ -22,7 +21,7 @@

def construct_memory_block_structures(
systems: list[StaticSystemType],
) -> list[SystemType]:
) -> list[BlockSystemType]:
"""
This function takes the systems (rod or rigid body) appended to the simulator class and
separates them into lists depending on if system is Cosserat rod or rigid body. Then using
Expand Down
11 changes: 6 additions & 5 deletions elastica/modules/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
OperatorFinalizeType,
StaticSystemType,
SystemType,
BlockSystemType,
ConnectionIndex,
)
from elastica.joint import FreeJoint
Expand All @@ -34,10 +35,12 @@ def id(self) -> Any: ...


class SystemCollectionProtocol(Protocol):
_systems: list[StaticSystemType]

def __len__(self) -> int: ...

def systems(self) -> Generator[StaticSystemType, None, None]: ...

def block_systems(self) -> Generator[BlockSystemType, None, None]: ...

@overload
def __getitem__(self, i: slice) -> list[SystemType]: ...
@overload
Expand Down Expand Up @@ -67,9 +70,7 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None: ...
@property
def _feature_group_finalize(self) -> list[OperatorFinalizeType]: ...

def systems(self) -> Generator[SystemType, None, None]: ...

def _get_sys_idx_if_valid(
def get_system_index(
self, sys_to_be_added: "SystemType | StaticSystemType"
) -> SystemIdxType: ...

Expand Down
3 changes: 3 additions & 0 deletions elastica/systems/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ def __init__(self):
def systems(self):
return self._memory_blocks

def block_systems(self):
return self._memory_blocks

def __getitem__(self, idx):
return self._memory_blocks[idx]

Expand Down
Loading

0 comments on commit 499ae6d

Please sign in to comment.