diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 1fe87ad1..f9eca74e 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -43,6 +43,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: } self.tc_policy = 0 + # Cached by `set_extended_timeout` so subsequent calls are a little faster + self._address_table_size: int | None = None + def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes: """Serialize the named frame and data.""" c, tx_schema, rx_schema = self.COMMANDS[name] @@ -252,3 +255,9 @@ async def read_counters(self) -> dict[t.EmberCounterType, int]: @abc.abstractmethod async def read_and_clear_counters(self) -> dict[t.EmberCounterType, int]: raise NotImplementedError + + @abc.abstractmethod + async def set_extended_timeout( + self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True + ) -> None: + raise NotImplementedError() diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index b1e9fb6d..534c842a 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +import random from typing import AsyncGenerator, Iterable import voluptuous as vol @@ -193,3 +194,44 @@ async def read_counters(self) -> dict[t.EmberCounterType, t.uint16_t]: async def read_and_clear_counters(self) -> dict[t.EmberCounterType, t.uint16_t]: (res,) = await self.readAndClearCounters() return dict(zip(t.EmberCounterType, res)) + + async def set_extended_timeout( + self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True + ) -> None: + (curr_extended_timeout,) = await self.getExtendedTimeout(remoteEui64=ieee) + + if curr_extended_timeout == extended_timeout: + return + + (node_id,) = await self.lookupNodeIdByEui64(eui64=ieee) + + # Check to see if we have an address table entry + if node_id != 0xFFFF: + await self.setExtendedTimeout( + remoteEui64=ieee, extendedTimeout=extended_timeout + ) + return + + if self._address_table_size is None: + (status, addr_table_size) = await self.getConfigurationValue( + t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE + ) + + if t.sl_Status.from_ember_status(status) != t.sl_Status.OK: + # Last-ditch effort + await self.setExtendedTimeout( + remoteEui64=ieee, extendedTimeout=extended_timeout + ) + return + + self._address_table_size = addr_table_size + + # Replace a random entry in the address table + index = random.randint(0, self._address_table_size - 1) + + await self.replaceAddressTableEntry( + addressTableIndex=index, + newEui64=ieee, + newId=nwk, + newExtendedTimeout=extended_timeout, + ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 22247178..08a1169e 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -749,8 +749,10 @@ async def send_packet(self, packet: zigpy.types.ZigbeePacket) -> None: async with self._req_lock: if packet.dst.addr_mode == zigpy.types.AddrMode.NWK: if packet.extended_timeout and device is not None: - await self._ezsp.setExtendedTimeout( - remoteEui64=device.ieee, extendedTimeout=True + await self._ezsp.set_extended_timeout( + nwk=device.nwk, + ieee=device.ieee, + extended_timeout=True, ) if packet.source_route is not None: diff --git a/tests/test_application.py b/tests/test_application.py index c846f6bf..e35adfc5 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -178,6 +178,7 @@ def form_network(): ) proto.factory_reset = AsyncMock(proto=proto.factory_reset) + proto.set_extended_timeout = AsyncMock(proto=proto.set_extended_timeout) proto.read_link_keys = MagicMock() proto.read_link_keys.return_value.__aiter__.return_value = [ @@ -842,6 +843,19 @@ async def test_send_packet_unicast_source_route(make_app, packet): ) +async def test_send_packet_unicast_extended_timeout(app, ieee, packet): + app.add_device(nwk=packet.dst.address, ieee=ieee) + + await _test_send_packet_unicast( + app, + packet.replace(extended_timeout=True), + ) + + assert app._ezsp._protocol.set_extended_timeout.mock_calls == [ + call(nwk=packet.dst.address, ieee=ieee, extended_timeout=True) + ] + + @patch("bellows.zigbee.application.RETRY_DELAYS", [0.01, 0.01, 0.01]) async def test_send_packet_unicast_retries_success(app, packet): await _test_send_packet_unicast( diff --git a/tests/test_ezsp_v4.py b/tests/test_ezsp_v4.py index 69e00243..1fec0187 100644 --- a/tests/test_ezsp_v4.py +++ b/tests/test_ezsp_v4.py @@ -1,5 +1,5 @@ import logging -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import pytest import zigpy.state @@ -379,3 +379,139 @@ async def test_read_counters(ezsp_f, length: int) -> None: ) assert counters1 == counters2 == {t.EmberCounterType(i): i for i in range(length)} + + +async def test_set_extended_timeout_no_entry(ezsp_f) -> None: + # Typical invocation + ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,) + ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,) # No address table entry + ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.SUCCESS, 8) + ezsp_f.replaceAddressTableEntry.return_value = ( + t.EmberStatus.SUCCESS, + t.EUI64.convert("ff:ff:ff:ff:ff:ff:ff:ff"), + 0xFFFF, + t.Bool.false, + ) + + with patch("bellows.ezsp.v4.random.randint") as mock_random: + mock_random.return_value = 0 + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.lookupNodeIdByEui64.mock_calls == [ + call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.getConfigurationValue.mock_calls == [ + call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) + ] + assert mock_random.mock_calls == [call(0, 8 - 1)] + assert ezsp_f.replaceAddressTableEntry.mock_calls == [ + call( + addressTableIndex=0, + newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + newId=0x1234, + newExtendedTimeout=True, + ) + ] + + # The address table size is cached + with patch("bellows.ezsp.v4.random.randint") as mock_random: + mock_random.return_value = 1 + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + # Still called only once + assert ezsp_f.getConfigurationValue.mock_calls == [ + call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) + ] + + assert ezsp_f.replaceAddressTableEntry.mock_calls == [ + call( + addressTableIndex=0, + newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + newId=0x1234, + newExtendedTimeout=True, + ), + call( + addressTableIndex=1, + newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + newId=0x1234, + newExtendedTimeout=True, + ), + ] + + +async def test_set_extended_timeout_already_set(ezsp_f) -> None: + # No-op, it's already set + ezsp_f.setExtendedTimeout.return_value = () + ezsp_f.getExtendedTimeout.return_value = (t.Bool.true,) + + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.setExtendedTimeout.mock_calls == [] + + +async def test_set_extended_timeout_already_have_entry(ezsp_f) -> None: + # An address table entry is present + ezsp_f.setExtendedTimeout.return_value = () + ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,) + ezsp_f.lookupNodeIdByEui64.return_value = (0x1234,) + + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.lookupNodeIdByEui64.mock_calls == [ + call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.setExtendedTimeout.mock_calls == [ + call( + remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), extendedTimeout=True + ) + ] + + +async def test_set_extended_timeout_bad_table_size(ezsp_f) -> None: + ezsp_f.setExtendedTimeout.return_value = () + ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,) + ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,) + ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.ERR_FATAL, 0xFF) + + with patch("bellows.ezsp.v4.random.randint") as mock_random: + mock_random.return_value = 0 + await ezsp_f.set_extended_timeout( + nwk=0x1234, + ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + extended_timeout=True, + ) + + assert ezsp_f.getExtendedTimeout.mock_calls == [ + call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.lookupNodeIdByEui64.mock_calls == [ + call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11")) + ] + assert ezsp_f.getConfigurationValue.mock_calls == [ + call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE) + ]