Skip to content

Commit

Permalink
Improve transmission reliability for sleepy end-devices (#646)
Browse files Browse the repository at this point in the history
* Implement `set_extended_timeout` in the protocol handler

* Add tests

* Add test for `send_packet` as well

* Cache the address table size once it's read
  • Loading branch information
puddly authored Aug 23, 2024
1 parent ccb3afd commit e160be2
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 3 deletions.
9 changes: 9 additions & 0 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
}
self.tc_policy = 0

# Cached by `set_extended_timeout` so subsequent calls are a little faster
self._address_table_size: int | None = None

def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes:
"""Serialize the named frame and data."""
c, tx_schema, rx_schema = self.COMMANDS[name]
Expand Down Expand Up @@ -252,3 +255,9 @@ async def read_counters(self) -> dict[t.EmberCounterType, int]:
@abc.abstractmethod
async def read_and_clear_counters(self) -> dict[t.EmberCounterType, int]:
raise NotImplementedError

@abc.abstractmethod
async def set_extended_timeout(
self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True
) -> None:
raise NotImplementedError()
42 changes: 42 additions & 0 deletions bellows/ezsp/v4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import logging
import random
from typing import AsyncGenerator, Iterable

import voluptuous as vol
Expand Down Expand Up @@ -193,3 +194,44 @@ async def read_counters(self) -> dict[t.EmberCounterType, t.uint16_t]:
async def read_and_clear_counters(self) -> dict[t.EmberCounterType, t.uint16_t]:
(res,) = await self.readAndClearCounters()
return dict(zip(t.EmberCounterType, res))

async def set_extended_timeout(
self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True
) -> None:
(curr_extended_timeout,) = await self.getExtendedTimeout(remoteEui64=ieee)

if curr_extended_timeout == extended_timeout:
return

(node_id,) = await self.lookupNodeIdByEui64(eui64=ieee)

# Check to see if we have an address table entry
if node_id != 0xFFFF:
await self.setExtendedTimeout(
remoteEui64=ieee, extendedTimeout=extended_timeout
)
return

if self._address_table_size is None:
(status, addr_table_size) = await self.getConfigurationValue(
t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE
)

if t.sl_Status.from_ember_status(status) != t.sl_Status.OK:
# Last-ditch effort
await self.setExtendedTimeout(
remoteEui64=ieee, extendedTimeout=extended_timeout
)
return

self._address_table_size = addr_table_size

# Replace a random entry in the address table
index = random.randint(0, self._address_table_size - 1)

await self.replaceAddressTableEntry(
addressTableIndex=index,
newEui64=ieee,
newId=nwk,
newExtendedTimeout=extended_timeout,
)
6 changes: 4 additions & 2 deletions bellows/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,10 @@ async def send_packet(self, packet: zigpy.types.ZigbeePacket) -> None:
async with self._req_lock:
if packet.dst.addr_mode == zigpy.types.AddrMode.NWK:
if packet.extended_timeout and device is not None:
await self._ezsp.setExtendedTimeout(
remoteEui64=device.ieee, extendedTimeout=True
await self._ezsp.set_extended_timeout(
nwk=device.nwk,
ieee=device.ieee,
extended_timeout=True,
)

if packet.source_route is not None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def form_network():
)

proto.factory_reset = AsyncMock(proto=proto.factory_reset)
proto.set_extended_timeout = AsyncMock(proto=proto.set_extended_timeout)

proto.read_link_keys = MagicMock()
proto.read_link_keys.return_value.__aiter__.return_value = [
Expand Down Expand Up @@ -842,6 +843,19 @@ async def test_send_packet_unicast_source_route(make_app, packet):
)


async def test_send_packet_unicast_extended_timeout(app, ieee, packet):
app.add_device(nwk=packet.dst.address, ieee=ieee)

await _test_send_packet_unicast(
app,
packet.replace(extended_timeout=True),
)

assert app._ezsp._protocol.set_extended_timeout.mock_calls == [
call(nwk=packet.dst.address, ieee=ieee, extended_timeout=True)
]


@patch("bellows.zigbee.application.RETRY_DELAYS", [0.01, 0.01, 0.01])
async def test_send_packet_unicast_retries_success(app, packet):
await _test_send_packet_unicast(
Expand Down
138 changes: 137 additions & 1 deletion tests/test_ezsp_v4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch

import pytest
import zigpy.state
Expand Down Expand Up @@ -379,3 +379,139 @@ async def test_read_counters(ezsp_f, length: int) -> None:
)

assert counters1 == counters2 == {t.EmberCounterType(i): i for i in range(length)}


async def test_set_extended_timeout_no_entry(ezsp_f) -> None:
# Typical invocation
ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,)
ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,) # No address table entry
ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.SUCCESS, 8)
ezsp_f.replaceAddressTableEntry.return_value = (
t.EmberStatus.SUCCESS,
t.EUI64.convert("ff:ff:ff:ff:ff:ff:ff:ff"),
0xFFFF,
t.Bool.false,
)

with patch("bellows.ezsp.v4.random.randint") as mock_random:
mock_random.return_value = 0
await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.lookupNodeIdByEui64.mock_calls == [
call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.getConfigurationValue.mock_calls == [
call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE)
]
assert mock_random.mock_calls == [call(0, 8 - 1)]
assert ezsp_f.replaceAddressTableEntry.mock_calls == [
call(
addressTableIndex=0,
newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
newId=0x1234,
newExtendedTimeout=True,
)
]

# The address table size is cached
with patch("bellows.ezsp.v4.random.randint") as mock_random:
mock_random.return_value = 1
await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

# Still called only once
assert ezsp_f.getConfigurationValue.mock_calls == [
call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE)
]

assert ezsp_f.replaceAddressTableEntry.mock_calls == [
call(
addressTableIndex=0,
newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
newId=0x1234,
newExtendedTimeout=True,
),
call(
addressTableIndex=1,
newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
newId=0x1234,
newExtendedTimeout=True,
),
]


async def test_set_extended_timeout_already_set(ezsp_f) -> None:
# No-op, it's already set
ezsp_f.setExtendedTimeout.return_value = ()
ezsp_f.getExtendedTimeout.return_value = (t.Bool.true,)

await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.setExtendedTimeout.mock_calls == []


async def test_set_extended_timeout_already_have_entry(ezsp_f) -> None:
# An address table entry is present
ezsp_f.setExtendedTimeout.return_value = ()
ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,)
ezsp_f.lookupNodeIdByEui64.return_value = (0x1234,)

await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.lookupNodeIdByEui64.mock_calls == [
call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.setExtendedTimeout.mock_calls == [
call(
remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), extendedTimeout=True
)
]


async def test_set_extended_timeout_bad_table_size(ezsp_f) -> None:
ezsp_f.setExtendedTimeout.return_value = ()
ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,)
ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,)
ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.ERR_FATAL, 0xFF)

with patch("bellows.ezsp.v4.random.randint") as mock_random:
mock_random.return_value = 0
await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.lookupNodeIdByEui64.mock_calls == [
call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.getConfigurationValue.mock_calls == [
call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE)
]

0 comments on commit e160be2

Please sign in to comment.