Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 131 additions & 26 deletions ingenialink/drive_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ def __init__(
set(complete_access_objects) if isinstance(complete_access_objects, list) else set()
)

self._original_register_values: dict[int, dict[str, Union[int, float, str, bytes]]] = {}
self._original_register_values: OrderedDict[
tuple[int, str], Union[int, float, str, bytes]
] = OrderedDict()

self._original_canopen_object_values: dict[CanOpenObject, bytes] = {}

# Key: (axis, uid), value
self._registers_changed = OrderedDict[tuple[int, str], Union[int, float, str, bytes]]()

self._objects_changed = set[CanOpenObject]()
self._objects_changed: dict[CanOpenObject, bytes] = {}

def _register_update_callback(
self,
Expand All @@ -104,14 +106,14 @@ def _register_update_callback(
return
if uid in self._do_not_restore_registers:
return
if uid not in self._original_register_values[register.subnode]:
dict_key = (register.subnode, uid)
if dict_key not in self._original_register_values:
return

# Check if the new value is different from the previous one
dict_key = (register.subnode, uid)
if dict_key in self._registers_changed:
previous_value = self._registers_changed[dict_key]
previous_value = self._original_register_values[register.subnode][uid]
previous_value = self._original_register_values[dict_key]
current_value = value if value is not None else previous_value
if current_value == previous_value:
return
Expand Down Expand Up @@ -155,15 +157,20 @@ def _complete_access_callback(
self._register_update_callback(servo=servo, register=register, value=value)
return

self._objects_changed.add(obj)
# Store the object as changed (actual value will be determined during restoration)
self._objects_changed[obj] = b"" # Placeholder, actual restore uses original value

logger.debug(f"{id(self)}: Object {obj.uid} changed using complete access to {value!r}.")

def _store_register_data(self) -> None:
"""Saves the value of all registers."""
def _store_register_data(self) -> OrderedDict[tuple[int, str], Union[int, float, str, bytes]]:
"""Reads and returns the value of all registers.

Returns:
OrderedDict mapping (axis, uid) tuple to register value.
"""
register_values: OrderedDict[tuple[int, str], Union[int, float, str, bytes]] = OrderedDict()
axes = list(self.drive.dictionary.subnodes) if self._axis is None else [self._axis]
for axis in axes:
self._original_register_values[axis] = {}
for uid, register in self.drive.dictionary.registers(subnode=axis).items():
if uid in self._do_not_restore_registers:
continue
Expand All @@ -183,11 +190,18 @@ def _store_register_data(self) -> None:
register_value = self.drive.read(uid, subnode=axis)
except ILIOError:
continue
self._original_register_values[axis][uid] = register_value
register_values[(axis, uid)] = register_value
return register_values

def _store_objects_data(self) -> dict[CanOpenObject, bytes]:
"""Reads and returns complete access objects data.

def _store_objects_data(self) -> None:
Returns:
Dictionary mapping CanOpenObject to its byte value.
"""
object_values: dict[CanOpenObject, bytes] = {}
if not isinstance(self.drive, EthercatServo):
return
return object_values
for obj in self.drive.dictionary.all_objs():
uid = obj.uid
# Always read the rpdo/tpdo map objects using complete access
Expand All @@ -210,21 +224,38 @@ def _store_objects_data(self) -> None:
obj_value = self.drive.read_complete_access(obj)
except Exception:
continue
self._original_canopen_object_values[obj] = obj_value
object_values[obj] = obj_value
return object_values

def _restore_register_data(self) -> None:
"""Restores the drive values."""
def _restore_register_data(
self,
original_values: OrderedDict[tuple[int, str], Union[int, float, str, bytes]],
changed_values: OrderedDict[tuple[int, str], Union[int, float, str, bytes]],
force_restore: bool = False,
) -> None:
"""Restores the drive values.

Args:
original_values: OrderedDict mapping (axis, uid) to original value.
changed_values: OrderedDict mapping (axis, uid) to changed value.
force_restore: If True, registers are being restored by force mode.
"""
axes = list(self.drive.dictionary.subnodes) if self._axis is None else [self._axis]
restored_registers: dict[int, list[str]] = {axis: [] for axis in axes}

for (axis, uid), current_value in reversed(self._registers_changed.items()):
for (axis, uid), current_value in reversed(changed_values.items()):
# No original data for the register
if uid not in self._original_register_values[axis]:
if (axis, uid) not in original_values:
continue
# Register has already been restored with a newer value than the evaluated one
if uid in restored_registers[axis]:
continue
restore_value = self._original_register_values[axis][uid]
# Skip PDO mapping registers: handled via complete access in _restore_objects_data
if force_restore and (
_PDO_RPDO_MAP_REGISTER_UID in uid or _PDO_TPDO_MAP_REGISTER_UID in uid
):
continue
restore_value = original_values[(axis, uid)]
# No change with respect to the original value
if current_value == restore_value:
continue
Expand All @@ -240,27 +271,101 @@ def _restore_register_data(self) -> None:
self.drive.write(uid, restore_value, subnode=axis)
restored_registers[axis].append(uid)

def _restore_objects_data(self) -> None:
for obj in self._objects_changed:
def _restore_objects_data(
self,
original_values: dict[CanOpenObject, bytes],
changed_values: dict[CanOpenObject, bytes],
) -> None:
"""Restores complete access objects.

Args:
original_values: Dictionary mapping CanOpenObject to its original byte value.
changed_values: Dictionary mapping CanOpenObject to changed byte value.
"""
for obj, current_value in changed_values.items():
# https://novantamotion.atlassian.net/browse/DRIVSUS-137
if _MON_DATA_OBJECT_UID in obj.uid or _DIST_DATA_OBJECT_UID in obj.uid:
continue
restore_value = self._original_canopen_object_values.get(obj, None)
restore_value = original_values.get(obj)
if restore_value is None:
raise ValueError(f"No original data for the object {obj} to restore.")
logger.warning(
f"No original data for the object {obj} to restore. Skipping restoration."
)
continue

# If we have current_value, check if it differs
if current_value and current_value == restore_value:
continue

logger.debug(f"Restoring {obj} using complete access.")
self.drive.write_complete_access(obj, restore_value)

def __enter__(self) -> None:
"""Subscribes to register update callbacks and saves the drive values."""
self._store_register_data()
self._store_objects_data()
self._original_register_values = self._store_register_data()
self._original_canopen_object_values = self._store_objects_data()
self.drive.register_update_subscribe(self._register_update_callback)
self.drive.register_update_complete_access_subscribe(self._complete_access_callback)

def force_restore(self, restore_registers: bool = True, restore_objects: bool = True) -> None:
"""Force restoration of all registers to their original values.

This method re-reads all registers that were originally stored in __enter__,
compares them with the original values, and restores any that have changed.
It ignores the current state of _registers_changed and _objects_changed,
effectively performing a complete refresh and restoration.

This is useful when changes have been made outside the context manager's
tracking (e.g., external modifications to the drive).

Args:
restore_registers: If True, restores registers to their original values.
restore_objects: If True, restores complete access objects to their original values.
"""
if not restore_registers and not restore_objects:
return

# Temporarily unsubscribe from callbacks to avoid re-populating tracking during restoration
self.drive.register_update_unsubscribe(self._register_update_callback)
self.drive.register_update_complete_access_unsubscribe(self._complete_access_callback)

try:
if restore_registers:
# Clear the current tracking
self._registers_changed.clear()
# Re-read current register values and restore any differences
current_register_values = self._store_register_data()
self._restore_register_data(
original_values=self._original_register_values,
changed_values=current_register_values,
force_restore=True,
)

if restore_objects:
# Clear the current tracking
self._objects_changed.clear()

# Re-read current object values and restore any differences
current_object_values = self._store_objects_data()
self._restore_objects_data(
original_values=self._original_canopen_object_values,
changed_values=current_object_values,
)
finally:
# Re-subscribe to callbacks
self.drive.register_update_subscribe(self._register_update_callback)
self.drive.register_update_complete_access_subscribe(self._complete_access_callback)

def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore [no-untyped-def]
"""Unsubscribes from register updates and restores the drive values."""
self.drive.register_update_unsubscribe(self._register_update_callback)
self.drive.register_update_complete_access_unsubscribe(self._complete_access_callback)
self._restore_register_data()
self._restore_objects_data()
self._restore_register_data(
original_values=self._original_register_values,
changed_values=self._registers_changed,
force_restore=False,
)
self._restore_objects_data(
original_values=self._original_canopen_object_values,
changed_values=self._objects_changed,
)
44 changes: 32 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pytest-cov = "==2.12.1"
pytest-mock = "==3.6.1"
pytest-console-scripts = "==1.4.1"
twisted = "==24.11.0"
summit-testing-framework = {extras = ["ingeniamotion"], version = "==0.1.5+pr58b12"}
summit-testing-framework = {extras = ["ingeniamotion"], version = "==0.1.5+pr61b9"}

# -----------------------------------------------------------------------
# TASKS
Expand Down
Loading