diff --git a/ingenialink/drive_context_manager.py b/ingenialink/drive_context_manager.py index a87dcdbe..7116f283 100644 --- a/ingenialink/drive_context_manager.py +++ b/ingenialink/drive_context_manager.py @@ -75,14 +75,16 @@ def __init__( set(complete_access_objects) if isinstance(complete_access_objects, list) else set() ) - self._original_register_values: dict[int, dict[str, Union[int, float, str, bytes]]] = {} + self._original_register_values: OrderedDict[ + tuple[int, str], Union[int, float, str, bytes] + ] = OrderedDict() self._original_canopen_object_values: dict[CanOpenObject, bytes] = {} # Key: (axis, uid), value self._registers_changed = OrderedDict[tuple[int, str], Union[int, float, str, bytes]]() - self._objects_changed = set[CanOpenObject]() + self._objects_changed: dict[CanOpenObject, bytes] = {} def _register_update_callback( self, @@ -104,14 +106,14 @@ def _register_update_callback( return if uid in self._do_not_restore_registers: return - if uid not in self._original_register_values[register.subnode]: + dict_key = (register.subnode, uid) + if dict_key not in self._original_register_values: return # Check if the new value is different from the previous one - dict_key = (register.subnode, uid) if dict_key in self._registers_changed: previous_value = self._registers_changed[dict_key] - previous_value = self._original_register_values[register.subnode][uid] + previous_value = self._original_register_values[dict_key] current_value = value if value is not None else previous_value if current_value == previous_value: return @@ -155,15 +157,20 @@ def _complete_access_callback( self._register_update_callback(servo=servo, register=register, value=value) return - self._objects_changed.add(obj) + # Store the object as changed (actual value will be determined during restoration) + self._objects_changed[obj] = b"" # Placeholder, actual restore uses original value logger.debug(f"{id(self)}: Object {obj.uid} changed using complete access to {value!r}.") - def _store_register_data(self) -> None: - """Saves the value of all registers.""" + def _store_register_data(self) -> OrderedDict[tuple[int, str], Union[int, float, str, bytes]]: + """Reads and returns the value of all registers. + + Returns: + OrderedDict mapping (axis, uid) tuple to register value. + """ + register_values: OrderedDict[tuple[int, str], Union[int, float, str, bytes]] = OrderedDict() axes = list(self.drive.dictionary.subnodes) if self._axis is None else [self._axis] for axis in axes: - self._original_register_values[axis] = {} for uid, register in self.drive.dictionary.registers(subnode=axis).items(): if uid in self._do_not_restore_registers: continue @@ -183,11 +190,18 @@ def _store_register_data(self) -> None: register_value = self.drive.read(uid, subnode=axis) except ILIOError: continue - self._original_register_values[axis][uid] = register_value + register_values[(axis, uid)] = register_value + return register_values + + def _store_objects_data(self) -> dict[CanOpenObject, bytes]: + """Reads and returns complete access objects data. - def _store_objects_data(self) -> None: + Returns: + Dictionary mapping CanOpenObject to its byte value. + """ + object_values: dict[CanOpenObject, bytes] = {} if not isinstance(self.drive, EthercatServo): - return + return object_values for obj in self.drive.dictionary.all_objs(): uid = obj.uid # Always read the rpdo/tpdo map objects using complete access @@ -210,21 +224,38 @@ def _store_objects_data(self) -> None: obj_value = self.drive.read_complete_access(obj) except Exception: continue - self._original_canopen_object_values[obj] = obj_value + object_values[obj] = obj_value + return object_values - def _restore_register_data(self) -> None: - """Restores the drive values.""" + def _restore_register_data( + self, + original_values: OrderedDict[tuple[int, str], Union[int, float, str, bytes]], + changed_values: OrderedDict[tuple[int, str], Union[int, float, str, bytes]], + force_restore: bool = False, + ) -> None: + """Restores the drive values. + + Args: + original_values: OrderedDict mapping (axis, uid) to original value. + changed_values: OrderedDict mapping (axis, uid) to changed value. + force_restore: If True, registers are being restored by force mode. + """ axes = list(self.drive.dictionary.subnodes) if self._axis is None else [self._axis] restored_registers: dict[int, list[str]] = {axis: [] for axis in axes} - for (axis, uid), current_value in reversed(self._registers_changed.items()): + for (axis, uid), current_value in reversed(changed_values.items()): # No original data for the register - if uid not in self._original_register_values[axis]: + if (axis, uid) not in original_values: continue # Register has already been restored with a newer value than the evaluated one if uid in restored_registers[axis]: continue - restore_value = self._original_register_values[axis][uid] + # Skip PDO mapping registers: handled via complete access in _restore_objects_data + if force_restore and ( + _PDO_RPDO_MAP_REGISTER_UID in uid or _PDO_TPDO_MAP_REGISTER_UID in uid + ): + continue + restore_value = original_values[(axis, uid)] # No change with respect to the original value if current_value == restore_value: continue @@ -240,27 +271,101 @@ def _restore_register_data(self) -> None: self.drive.write(uid, restore_value, subnode=axis) restored_registers[axis].append(uid) - def _restore_objects_data(self) -> None: - for obj in self._objects_changed: + def _restore_objects_data( + self, + original_values: dict[CanOpenObject, bytes], + changed_values: dict[CanOpenObject, bytes], + ) -> None: + """Restores complete access objects. + + Args: + original_values: Dictionary mapping CanOpenObject to its original byte value. + changed_values: Dictionary mapping CanOpenObject to changed byte value. + """ + for obj, current_value in changed_values.items(): # https://novantamotion.atlassian.net/browse/DRIVSUS-137 if _MON_DATA_OBJECT_UID in obj.uid or _DIST_DATA_OBJECT_UID in obj.uid: continue - restore_value = self._original_canopen_object_values.get(obj, None) + restore_value = original_values.get(obj) if restore_value is None: - raise ValueError(f"No original data for the object {obj} to restore.") + logger.warning( + f"No original data for the object {obj} to restore. Skipping restoration." + ) + continue + + # If we have current_value, check if it differs + if current_value and current_value == restore_value: + continue + logger.debug(f"Restoring {obj} using complete access.") self.drive.write_complete_access(obj, restore_value) def __enter__(self) -> None: """Subscribes to register update callbacks and saves the drive values.""" - self._store_register_data() - self._store_objects_data() + self._original_register_values = self._store_register_data() + self._original_canopen_object_values = self._store_objects_data() self.drive.register_update_subscribe(self._register_update_callback) self.drive.register_update_complete_access_subscribe(self._complete_access_callback) + def force_restore(self, restore_registers: bool = True, restore_objects: bool = True) -> None: + """Force restoration of all registers to their original values. + + This method re-reads all registers that were originally stored in __enter__, + compares them with the original values, and restores any that have changed. + It ignores the current state of _registers_changed and _objects_changed, + effectively performing a complete refresh and restoration. + + This is useful when changes have been made outside the context manager's + tracking (e.g., external modifications to the drive). + + Args: + restore_registers: If True, restores registers to their original values. + restore_objects: If True, restores complete access objects to their original values. + """ + if not restore_registers and not restore_objects: + return + + # Temporarily unsubscribe from callbacks to avoid re-populating tracking during restoration + self.drive.register_update_unsubscribe(self._register_update_callback) + self.drive.register_update_complete_access_unsubscribe(self._complete_access_callback) + + try: + if restore_registers: + # Clear the current tracking + self._registers_changed.clear() + # Re-read current register values and restore any differences + current_register_values = self._store_register_data() + self._restore_register_data( + original_values=self._original_register_values, + changed_values=current_register_values, + force_restore=True, + ) + + if restore_objects: + # Clear the current tracking + self._objects_changed.clear() + + # Re-read current object values and restore any differences + current_object_values = self._store_objects_data() + self._restore_objects_data( + original_values=self._original_canopen_object_values, + changed_values=current_object_values, + ) + finally: + # Re-subscribe to callbacks + self.drive.register_update_subscribe(self._register_update_callback) + self.drive.register_update_complete_access_subscribe(self._complete_access_callback) + def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore [no-untyped-def] """Unsubscribes from register updates and restores the drive values.""" self.drive.register_update_unsubscribe(self._register_update_callback) self.drive.register_update_complete_access_unsubscribe(self._complete_access_callback) - self._restore_register_data() - self._restore_objects_data() + self._restore_register_data( + original_values=self._original_register_values, + changed_values=self._registers_changed, + force_restore=False, + ) + self._restore_objects_data( + original_values=self._original_canopen_object_values, + changed_values=self._objects_changed, + ) diff --git a/poetry.lock b/poetry.lock index edb35ed0..45cda255 100644 --- a/poetry.lock +++ b/poetry.lock @@ -930,11 +930,11 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["dev", "tests"] -markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10"}, {file = "exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88"}, ] +markers = {dev = "python_version < \"3.11\""} [package.dependencies] typing-extensions = {version = ">=4.6.0", markers = "python_version < \"3.13\""} @@ -992,6 +992,24 @@ files = [ [package.dependencies] packaging = ">=20" +[[package]] +name = "fsoe-master" +version = "0.2.1" +description = "FSoE Master Library" +optional = false +python-versions = ">=3.9" +groups = ["tests"] +files = [ + {file = "fsoe_master-0.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7d9c97d9c6805398ded85a6adc9acde91eb4355cc006712340104e44e9a4090"}, + {file = "fsoe_master-0.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:24f290176a4763a11de45abea7c01d606dd88717cec1100eb0f2c2a3982acc24"}, + {file = "fsoe_master-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7c2a1b80a430d29f319ef2b012499abd7f4d6516f05c89282ebd7fb12029f86"}, + {file = "fsoe_master-0.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:724368ff950db631ca66db8fb912cc1e34c047dc0d978c3a00a0d5fbe5db4a03"}, + {file = "fsoe_master-0.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a933de618bde562b6b868e537f35ac9c2de2168333e93f9fc0145d18b3bbf02"}, + {file = "fsoe_master-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:cff774613524f03dfdc3842c3f307d84dae8cf615fc69d45db1971278373d1aa"}, + {file = "fsoe_master-0.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b022871aed62743f22a88a4152ad8ded50497299c113b546c8dec4d1285edf60"}, + {file = "fsoe_master-0.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:91a30e23883d7edd15dedd1a4c2ecb18c508bc463eb8c54f0b9cd9bb012e6b6d"}, +] + [[package]] name = "h11" version = "0.16.0" @@ -1207,23 +1225,25 @@ files = [ [[package]] name = "ingeniamotion" -version = "0.10.0" +version = "0.10.1" description = "Motion library for Novanta servo drives" optional = false python-versions = ">=3.9" groups = ["tests"] files = [ - {file = "ingeniamotion-0.10.0-py3-none-any.whl", hash = "sha256:24facebd529cd42484fab19624da156a453a598f826ef43149076929b766224c"}, - {file = "ingeniamotion-0.10.0.tar.gz", hash = "sha256:f1428eb8e5f204fd3426934422e8e42e30f46d074bfa387319d428848057f348"}, + {file = "ingeniamotion-0.10.1-py3-none-any.whl", hash = "sha256:67ff43e2ec8786aa266570d65b212b4d387148400520914eabab536cae131df7"}, + {file = "ingeniamotion-0.10.1.tar.gz", hash = "sha256:5b29ba65f75b1fcffe9bd6c1236f00394209b7c06c4b2d2024e60ecca8e5b5c3"}, ] [package.dependencies] +exceptiongroup = ">=1.3.0,<2.0.0" +fsoe_master = {version = "0.2.1", optional = true, markers = "extra == \"fsoe\""} ifaddr = "0.1.7" -ingenialink = ">=7.5.1,<8.0.0" +ingenialink = ">=7.5.2,<8.0.0" ingenialogger = ">=0.2.1" [package.extras] -fsoe = ["fsoe_master (==0.2.0)"] +fsoe = ["fsoe_master (==0.2.1)"] [[package]] name = "iniconfig" @@ -3624,25 +3644,25 @@ test = ["pytest"] [[package]] name = "summit-testing-framework" -version = "0.1.5+pr58b12" +version = "0.1.5+pr61b9" description = "Testing framework for Novanta drives" optional = false python-versions = ">=3.9" groups = ["tests"] files = [ - {file = "summit_testing_framework-0.1.5+pr58b12-py3-none-any.whl", hash = "sha256:8e3d765a1440b6dde36fdad79052ba5f134241795db462de512e00eaa99dba1b"}, + {file = "summit_testing_framework-0.1.5+pr61b9-py3-none-any.whl", hash = "sha256:a0541b305947da93ea75c88aeb95bbebfe262222ad1b96a7ac7d183d8e91558f"}, ] [package.dependencies] ingenia-att-api = "2.3.1" -ingeniamotion = {version = ">=0.9.2", optional = true, markers = "extra == \"ingeniamotion\""} -pytest = ">=8.4.1" +ingeniamotion = {version = ">=0.10.1", extras = ["fsoe"], optional = true, markers = "extra == \"ingeniamotion\""} +pytest = ">=8.4.1,<9.0.0" pytest-env = ">=1.1.5" rpyc = "6.0.0" [package.extras] ingenialink = ["ingenialink (>=7.5.2)"] -ingeniamotion = ["ingeniamotion (>=0.9.2)"] +ingeniamotion = ["ingeniamotion[fsoe] (>=0.10.1)"] [package.source] type = "legacy" @@ -4312,4 +4332,4 @@ virtual-drive = ["virtual_drive"] [metadata] lock-version = "2.1" python-versions = ">=3.9" -content-hash = "0b5f489a3b75170a1bc2f85aa7e8349b4315fa7a8882ab437eea19e2ab93ac8f" +content-hash = "442c1625892f0952ccec732519a22431a5386b38958e1c6c20cdfca4212553bd" diff --git a/pyproject.toml b/pyproject.toml index 1957fad3..ce9e465a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ pytest-cov = "==2.12.1" pytest-mock = "==3.6.1" pytest-console-scripts = "==1.4.1" twisted = "==24.11.0" -summit-testing-framework = {extras = ["ingeniamotion"], version = "==0.1.5+pr58b12"} +summit-testing-framework = {extras = ["ingeniamotion"], version = "==0.1.5+pr61b9"} # ----------------------------------------------------------------------- # TASKS diff --git a/tests/test_drive_context_manager.py b/tests/test_drive_context_manager.py index da6e2121..cac4c81d 100644 --- a/tests/test_drive_context_manager.py +++ b/tests/test_drive_context_manager.py @@ -168,7 +168,7 @@ def test_drive_context_manager_restores_complete_access_registers( with context: assert context._registers_changed == {} - assert context._objects_changed == set() + assert context._objects_changed == {} servo.set_pdo_map_to_slave([rpdo_map], [tpdo_map]) servo.map_pdos(slave_index=setup_descriptor.slave) @@ -180,3 +180,200 @@ def test_drive_context_manager_restores_complete_access_registers( assert "ETG_COMMS_TPDO_ASSIGN" in objects_uids assert "ETG_COMMS_RPDO_MAP1" in objects_uids assert "ETG_COMMS_TPDO_MAP1" in objects_uids + + +@pytest.mark.ethernet +@pytest.mark.ethercat +@pytest.mark.canopen +@pytest.mark.virtual +def test_force_restore_with_external_changes( + setup_manager: tuple["Network", Union[str, list[str]], "DriveEnvironmentController"], +) -> None: + """Test that force_restore detects and restores changes made outside the context manager.""" + net, _, _ = setup_manager + servo = net.servos[0] + context = DriveContextManager(servo) + + new_over_volt_value = 100.0 + previous_over_volt_value = _read_user_over_voltage_uid(servo) + if previous_over_volt_value == new_over_volt_value: + new_over_volt_value -= 1.0 + + new_under_volt_value = 1.0 + previous_under_volt_value = _read_user_under_voltage_uid(servo) + if previous_under_volt_value == new_under_volt_value: + new_under_volt_value += 1.0 + + with context: + # Make a tracked change + servo.write(_USER_OVER_VOLTAGE_UID, new_over_volt_value, subnode=1) + assert _read_user_over_voltage_uid(servo) == new_over_volt_value + assert (1, _USER_OVER_VOLTAGE_UID) in context._registers_changed + + # Simulate an external change (bypass the callback by directly modifying the drive) + # In reality, this would be done by another process/connection + servo.write(_USER_UNDER_VOLTAGE_UID, new_under_volt_value, subnode=1) + + # Clear the tracking to simulate that this change wasn't tracked + context._registers_changed.pop((1, _USER_UNDER_VOLTAGE_UID), None) + + # Verify the external change is present + assert _read_user_under_voltage_uid(servo) == new_under_volt_value + + # Force restore should detect and restore both changes + context.force_restore() + + # Both registers should now be back to original values + assert _read_user_over_voltage_uid(servo) == previous_over_volt_value + assert _read_user_under_voltage_uid(servo) == previous_under_volt_value + + # Tracking should be cleared + assert context._registers_changed == {} + assert context._objects_changed == {} + + +@pytest.mark.ethernet +@pytest.mark.ethercat +@pytest.mark.canopen +@pytest.mark.virtual +def test_force_restore_clears_tracking( + setup_manager: tuple["Network", Union[str, list[str]], "DriveEnvironmentController"], +) -> None: + """Test that force_restore clears the internal tracking dictionaries.""" + net, _, _ = setup_manager + servo = net.servos[0] + context = DriveContextManager(servo) + + new_reg_value = 100.0 + previous_reg_value = _read_user_over_voltage_uid(servo) + if previous_reg_value == new_reg_value: + new_reg_value -= 1.0 + + with context: + # Make changes + servo.write(_USER_OVER_VOLTAGE_UID, new_reg_value, subnode=1) + assert (1, _USER_OVER_VOLTAGE_UID) in context._registers_changed + + # Force restore + context.force_restore() + + # Verify tracking is cleared + assert context._registers_changed == {} + assert context._objects_changed == {} + + # Verify register was restored + assert _read_user_over_voltage_uid(servo) == previous_reg_value + + +@pytest.mark.ethernet +@pytest.mark.ethercat +@pytest.mark.canopen +@pytest.mark.virtual +def test_force_restore_only_restores_changed_values( + setup_manager: tuple["Network", Union[str, list[str]], "DriveEnvironmentController"], +) -> None: + """Test that force_restore only restores registers that have actually changed.""" + net, _, _ = setup_manager + servo = net.servos[0] + context = DriveContextManager(servo) + + new_reg_value = 100.0 + previous_over_volt_value = _read_user_over_voltage_uid(servo) + if previous_over_volt_value == new_reg_value: + new_reg_value -= 1.0 + + previous_under_volt_value = _read_user_under_voltage_uid(servo) + + with context: + # Change only one register + servo.write(_USER_OVER_VOLTAGE_UID, new_reg_value, subnode=1) + assert _read_user_over_voltage_uid(servo) == new_reg_value + + # The other register should still have its original value + assert _read_user_under_voltage_uid(servo) == previous_under_volt_value + + # Force restore should only restore the changed register + context.force_restore() + + # Changed register should be restored + assert _read_user_over_voltage_uid(servo) == previous_over_volt_value + # Unchanged register should still have original value + assert _read_user_under_voltage_uid(servo) == previous_under_volt_value + + +@pytest.mark.ethernet +@pytest.mark.ethercat +@pytest.mark.canopen +@pytest.mark.virtual +def test_force_restore_multiple_times( + setup_manager: tuple["Network", Union[str, list[str]], "DriveEnvironmentController"], +) -> None: + """Test that force_restore can be called multiple times.""" + net, _, _ = setup_manager + servo = net.servos[0] + context = DriveContextManager(servo) + + new_reg_value = 100.0 + previous_reg_value = _read_user_over_voltage_uid(servo) + if previous_reg_value == new_reg_value: + new_reg_value -= 1.0 + + with context: + # Make a change + servo.write(_USER_OVER_VOLTAGE_UID, new_reg_value, subnode=1) + assert _read_user_over_voltage_uid(servo) == new_reg_value + + # First force restore + context.force_restore() + assert _read_user_over_voltage_uid(servo) == previous_reg_value + + # Make another change + new_reg_value_2 = new_reg_value - 10 + if previous_reg_value == new_reg_value_2: + new_reg_value_2 -= 1.0 + servo.write(_USER_OVER_VOLTAGE_UID, new_reg_value_2, subnode=1) + assert _read_user_over_voltage_uid(servo) == new_reg_value_2 + + # Second force restore + context.force_restore() + assert _read_user_over_voltage_uid(servo) == previous_reg_value + + +@pytest.mark.ethercat +def test_force_restore_with_complete_access_objects( + setup_manager: tuple["EthercatNetwork", str, "DriveEnvironmentController"], + setup_descriptor: "DriveEcatSetup", +) -> None: + """Test that force_restore works with complete access objects (PDO mappings).""" + net, _, _ = setup_manager + servo = net.servos[0] + context = DriveContextManager(servo) + + # Store original PDO state + servo.reset_rpdo_mapping() + servo.reset_tpdo_mapping() + + tpdo_map = TPDOMap() + tpdo_registers = ["CL_POS_FBK_VALUE", "CL_VEL_FBK_VALUE"] + for tpdo_register in tpdo_registers: + register = servo.dictionary.get_register(tpdo_register) + tpdo_map.add_registers(register) + + rpdo_map = RPDOMap() + rpdo_registers = ["CL_POS_SET_POINT_VALUE", "CL_VEL_SET_POINT_VALUE"] + for rpdo_register in rpdo_registers: + register = servo.dictionary.get_register(rpdo_register) + rpdo_map.add_registers(register) + + with context: + # Change PDO mappings + servo.set_pdo_map_to_slave([rpdo_map], [tpdo_map]) + servo.map_pdos(slave_index=setup_descriptor.slave) + + assert len(context._objects_changed) > 0 + + # Force restore should restore PDO mappings to original state + context.force_restore() + + # Tracking should be cleared + assert context._objects_changed == {}