Skip to content

Commit

Permalink
test: fix test for new synchronize ordering policy
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed May 9, 2024
1 parent 15ba162 commit 8151230
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 70 deletions.
12 changes: 5 additions & 7 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@ class BaseSystemCollection(MutableSequence):
_systems: list
List of rod-like objects.
"""

"""
Developer Note
-----
Note
----
We can directly subclass a list for the
Expand Down Expand Up @@ -174,19 +172,19 @@ def finalize(self):
def synchronize(self, time: float):
# Collection call _feature_group_synchronize
for func in self._feature_group_synchronize:
func(time)
func(time=time)

def constrain_values(self, time: float):
# Collection call _feature_group_constrain_values
for func in self._feature_group_constrain_values:
func(time)
func(time=time)

def constrain_rates(self, time: float):
# Collection call _feature_group_constrain_rates
for func in self._feature_group_constrain_rates:
func(time)
func(time=time)

def apply_callbacks(self, time: float, current_step: int):
# Collection call _feature_group_callback
for func in self._feature_group_callback:
func(time, current_step)
func(time=time, current_step=current_step)
23 changes: 11 additions & 12 deletions elastica/modules/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
Provides the connections interface to connect entities (rods,
rigid bodies) using joints (see `joints.py`).
"""
import functools
import numpy as np
from elastica.joint import FreeJoint

Expand Down Expand Up @@ -60,16 +59,15 @@ def connect(
sys_dofs = [self._systems[idx].n_elems for idx in sys_idx]

# Create _Connect object, cache it and return to user
_connector = _Connect(*sys_idx, *sys_dofs)
_connector.set_index(first_connect_idx, second_connect_idx)
self._connections.append(_connector)
self._feature_group_synchronize.append_id(_connector)
_connect = _Connect(*sys_idx, *sys_dofs)
_connect.set_index(first_connect_idx, second_connect_idx)
self._connections.append(_connect)
self._feature_group_synchronize.append_id(_connect)

return _connector
return _connect

def _finalize_connections(self):
# From stored _Connect objects, instantiate the joints and store it

# dev : the first indices stores the
# (first rod index, second_rod_idx, connection_idx_on_first_rod, connection_idx_on_second_rod)
# to apply the connections to.
Expand All @@ -82,17 +80,15 @@ def _finalize_connections(self):

# FIXME: lambda t is included because OperatorType takes time as an argument
def apply_forces(time):
return functools.partial(
connect_instance.apply_forces,
connect_instance.apply_forces(
system_one=self._systems[first_sys_idx],
index_one=first_connect_idx,
system_two=self._systems[second_sys_idx],
index_two=second_connect_idx,
)

def apply_torques(time):
return functools.partial(
connect_instance.apply_torques,
connect_instance.apply_torques(
system_one=self._systems[first_sys_idx],
index_one=first_connect_idx,
system_two=self._systems[second_sys_idx],
Expand All @@ -103,6 +99,9 @@ def apply_torques(time):
connection, [apply_forces, apply_torques]
)

self._connections = []
del self._connections

# Need to finally solve CPP here, if we are doing things properly
# This is to optimize the call tree for better memory accesses
# https://brooksandrew.github.io/simpleblog/articles/intro-to-graph-optimization-solving-cpp/
Expand Down Expand Up @@ -156,7 +155,7 @@ def __init__(
def set_index(self, first_idx, second_idx):
# TODO assert range
# First check if the types of first rod idx and second rod idx variable are same.
assert type(first_idx) == type(
assert type(first_idx) is type(
second_idx
), "Type of first_connect_idx :{}".format(
type(first_idx)
Expand Down
6 changes: 4 additions & 2 deletions elastica/modules/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,16 @@ def _finalize_contact(self) -> None:
)

def apply_contact(time):
return functools.partial(
contact_instance.apply_contact,
contact_instance.apply_contact(
system_one=self._systems[first_sys_idx],
system_two=self._systems[second_sys_idx],
)

self._feature_group_synchronize.add_operators(contact, [apply_contact])

self._contacts = []
del self._contacts


class _Contact:
"""
Expand Down
6 changes: 3 additions & 3 deletions elastica/modules/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def __init__(self):
self._operator_collection: list[list[OperatorType]] = []
self._operator_ids: list[int] = []

def __iter__(self) -> Callable[[...], None]:
def __iter__(self) -> OperatorType:
if not self._operator_collection:
raise RuntimeError("Feature group is not instantiated.")
operator_chain = itertools.chain(self._operator_collection)
operator_chain = itertools.chain.from_iterable(self._operator_collection)
for operator in operator_chain:
yield operator

Expand All @@ -23,5 +23,5 @@ def append_id(self, feature):
self._operator_collection.append([])

def add_operators(self, feature, operators: list[OperatorType]):
idx = self._operator_ids.index(feature)
idx = self._operator_ids.index(id(feature))
self._operator_collection[idx].extend(operators)
3 changes: 3 additions & 0 deletions elastica/modules/forcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def _finalize_forcing(self):
ext_force_torque, [apply_forces, apply_torques]
)

self._ext_forces_torques = []
del self._ext_forces_torques


class _ExtForceTorque:
"""
Expand Down
13 changes: 12 additions & 1 deletion tests/test_modules/test_base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,18 @@ def test_forcing(self, load_collection, legal_forces):
simulator_class.add_forcing_to(rod).using(legal_forces)
simulator_class.finalize()
# After finalize check if the created forcing object is instance of the class we have given.
assert isinstance(simulator_class._ext_forces_torques[-1][-1], legal_forces)
assert isinstance(
simulator_class._feature_group_synchronize._operator_collection[-1][
-1
].func.__self__,
legal_forces,
)
assert isinstance(
simulator_class._feature_group_synchronize._operator_collection[-1][
-2
].func.__self__,
legal_forces,
)

# TODO: this is a dummy test for synchronize find a better way to test them
simulator_class.synchronize(time=0)
Expand Down
43 changes: 20 additions & 23 deletions tests/test_modules/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_call_without_setting_connect_throws_runtime_error(self, load_connect):
connect = load_connect

with pytest.raises(RuntimeError) as excinfo:
connect()
connect.instantiate()
assert "No connections provided" in str(excinfo.value)

def test_call_improper_args_throws(self, load_connect):
Expand All @@ -173,7 +173,7 @@ def mock_init(self, *args, **kwargs):

# Actual test is here, this should not throw
with pytest.raises(TypeError) as excinfo:
_ = connect()
_ = connect.instantiate()
assert (
r"Unable to construct connection class.\nDid you provide all necessary joint properties?"
== str(excinfo.value)
Expand Down Expand Up @@ -327,21 +327,18 @@ def mock_init(self, *args, **kwargs):

def test_connect_finalize_correctness(self, load_rod_with_connects):
system_collection_with_connections, connect_cls = load_rod_with_connects
connect = system_collection_with_connections._connections[0]
assert connect._connect_cls == connect_cls

system_collection_with_connections._finalize_connections()
assert (
system_collection_with_connections._feature_group_synchronize._operator_ids[
0
]
== id(connect)
)

for (
fidx,
sidx,
fconnect,
sconnect,
connect,
) in system_collection_with_connections._connections:
assert type(fidx) is int
assert type(sidx) is int
assert fconnect is None
assert sconnect is None
assert type(connect) is connect_cls
assert not hasattr(system_collection_with_connections, "_connections")

@pytest.fixture
def load_rod_with_connects_and_indices(self, load_system_with_connects):
Expand Down Expand Up @@ -392,17 +389,17 @@ def test_connect_call_on_systems(self, load_rod_with_connects_and_indices):
system_collection_with_connections_and_indices,
connect_cls,
) = load_rod_with_connects_and_indices
mock_connections = [
c for c in system_collection_with_connections_and_indices._connections
]

system_collection_with_connections_and_indices._finalize_connections()
system_collection_with_connections_and_indices._call_connections()

for (
fidx,
sidx,
fconnect,
sconnect,
connect,
) in system_collection_with_connections_and_indices._connections:
system_collection_with_connections_and_indices.synchronize(0)

for connection in mock_connections:
fidx, sidx, fconnect, sconnect = connection.id()
connect = connection.instantiate()

end_distance_vector = (
system_collection_with_connections_and_indices._systems[
sidx
Expand Down
27 changes: 14 additions & 13 deletions tests/test_modules/test_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_call_without_setting_contact_throws_runtime_error(self, load_contact):
contact = load_contact

with pytest.raises(RuntimeError) as excinfo:
contact()
contact.instantiate()
assert "No contacts provided to to establish contact between rod-like object id {0} and {1}, but a Contact was intended as per code. Did you forget to call the `using` method?".format(
*contact.id()
) == str(
Expand All @@ -75,7 +75,7 @@ def mock_init(self, *args, **kwargs):

# Actual test is here, this should not throw
with pytest.raises(TypeError) as excinfo:
_ = contact()
_ = contact.instantiate()
assert (
r"Unable to construct contact class.\nDid you provide all necessary contact properties?"
== str(excinfo.value)
Expand Down Expand Up @@ -260,13 +260,15 @@ def mock_init(self, *args, **kwargs):

def test_contact_finalize_correctness(self, load_rod_with_contacts):
system_collection_with_contacts, contact_cls = load_rod_with_contacts
contact = system_collection_with_contacts._contacts[0].instantiate()
fidx, sidx = system_collection_with_contacts._contacts[0].id()

system_collection_with_contacts._finalize_contact()

for fidx, sidx, contact in system_collection_with_contacts._contacts:
assert type(fidx) is int
assert type(sidx) is int
assert type(contact) is contact_cls
assert not hasattr(system_collection_with_contacts, "_contacts")
assert type(fidx) is int
assert type(sidx) is int
assert type(contact) is contact_cls

@pytest.fixture
def load_contact_objects_with_incorrect_order(self, load_system_with_contacts):
Expand Down Expand Up @@ -339,19 +341,18 @@ def load_system_with_rods_in_contact(self, load_system_with_contacts):
return system_collection_with_rods_in_contact

def test_contact_call_on_systems(self, load_system_with_rods_in_contact):
from elastica.contact_forces import _calculate_contact_forces_rod_rod

system_collection_with_rods_in_contact = load_system_with_rods_in_contact
mock_contacts = [c for c in system_collection_with_rods_in_contact._contacts]

system_collection_with_rods_in_contact._finalize_contact()
system_collection_with_rods_in_contact._call_contacts(time=0)
system_collection_with_rods_in_contact.synchronize(time=0)

from elastica.contact_forces import _calculate_contact_forces_rod_rod
for _contact in mock_contacts:
fidx, sidx = _contact.id()
contact = _contact.instantiate()

for (
fidx,
sidx,
contact,
) in system_collection_with_rods_in_contact._contacts:
system_one = system_collection_with_rods_in_contact._systems[fidx]
system_two = system_collection_with_rods_in_contact._systems[sidx]
external_forces_system_one = np.zeros_like(system_one.external_forces)
Expand Down
23 changes: 14 additions & 9 deletions tests/test_modules/test_forcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_call_without_setting_forcing_throws_runtime_error(self, load_forcing):
forcing = load_forcing

with pytest.raises(RuntimeError) as excinfo:
forcing(None) # None is the rod/system parameter
forcing.instantiate() # None is the rod/system parameter
assert "No forcing" in str(excinfo.value)

def test_call_improper_args_throws(self, load_forcing):
Expand All @@ -62,7 +62,7 @@ def mock_init(self, *args, **kwargs):

# Actual test is here, this should not throw
with pytest.raises(TypeError) as excinfo:
_ = forcing()
_ = forcing.instantiate()
assert "Unable to construct" in str(excinfo.value)


Expand Down Expand Up @@ -166,7 +166,7 @@ def mock_init(self, *args, **kwargs):

return scwf, MockForcing

def test_friction_plane_forcing_class_sorting(self, load_system_with_forcings):
def test_friction_plane_forcing_class(self, load_system_with_forcings):

scwf = load_system_with_forcings

Expand Down Expand Up @@ -196,19 +196,24 @@ def mock_init(self, *args, **kwargs):
)
scwf.add_forcing_to(1).using(MockForcing, 2, 42) # index based forcing

# Now check if the Anisotropic friction and the MockForcing are in the list
assert scwf._ext_forces_torques[-1]._forcing_cls == MockForcing
assert scwf._ext_forces_torques[-2]._forcing_cls == AnisotropicFrictionalPlane
scwf._finalize_forcing()

# Now check if the Anisotropic friction is the last forcing class
assert isinstance(scwf._ext_forces_torques[-1][-1], AnisotropicFrictionalPlane)
assert not hasattr(scwf, "_ext_forces_torques")

def test_constrain_finalize_correctness(self, load_rod_with_forcings):
scwf, forcing_cls = load_rod_with_forcings
forcing_features = [f for f in scwf._ext_forces_torques]

scwf._finalize_forcing()
assert not hasattr(scwf, "_ext_forces_torques")

for x, y in scwf._ext_forces_torques:
assert type(x) is int
assert type(y) is forcing_cls
for _forcing in forcing_features:
x = _forcing.id()
y = _forcing.instantiate()
assert isinstance(x, int)
assert isinstance(y, forcing_cls)

@pytest.mark.xfail
def test_constrain_finalize_sorted(self, load_rod_with_forcings):
Expand Down

0 comments on commit 8151230

Please sign in to comment.