Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve transmission reliability for sleepy end-devices #646

Merged
merged 5 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
}
self.tc_policy = 0

# Cached by `set_extended_timeout` so subsequent calls are a little faster
self._address_table_size: int | None = None

def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes:
"""Serialize the named frame and data."""
c, tx_schema, rx_schema = self.COMMANDS[name]
Expand Down Expand Up @@ -252,3 +255,9 @@ async def read_counters(self) -> dict[t.EmberCounterType, int]:
@abc.abstractmethod
async def read_and_clear_counters(self) -> dict[t.EmberCounterType, int]:
raise NotImplementedError

@abc.abstractmethod
async def set_extended_timeout(
self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True
) -> None:
raise NotImplementedError()
42 changes: 42 additions & 0 deletions bellows/ezsp/v4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import logging
import random
from typing import AsyncGenerator, Iterable

import voluptuous as vol
Expand Down Expand Up @@ -193,3 +194,44 @@ async def read_counters(self) -> dict[t.EmberCounterType, t.uint16_t]:
async def read_and_clear_counters(self) -> dict[t.EmberCounterType, t.uint16_t]:
(res,) = await self.readAndClearCounters()
return dict(zip(t.EmberCounterType, res))

async def set_extended_timeout(
self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True
) -> None:
(curr_extended_timeout,) = await self.getExtendedTimeout(remoteEui64=ieee)

if curr_extended_timeout == extended_timeout:
return

(node_id,) = await self.lookupNodeIdByEui64(eui64=ieee)

# Check to see if we have an address table entry
if node_id != 0xFFFF:
await self.setExtendedTimeout(
remoteEui64=ieee, extendedTimeout=extended_timeout
)
return

if self._address_table_size is None:
(status, addr_table_size) = await self.getConfigurationValue(
t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE
)

if t.sl_Status.from_ember_status(status) != t.sl_Status.OK:
# Last-ditch effort
await self.setExtendedTimeout(
remoteEui64=ieee, extendedTimeout=extended_timeout
)
return

self._address_table_size = addr_table_size

# Replace a random entry in the address table
index = random.randint(0, self._address_table_size - 1)

await self.replaceAddressTableEntry(
addressTableIndex=index,
newEui64=ieee,
newId=nwk,
newExtendedTimeout=extended_timeout,
)
6 changes: 4 additions & 2 deletions bellows/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,10 @@ async def send_packet(self, packet: zigpy.types.ZigbeePacket) -> None:
async with self._req_lock:
if packet.dst.addr_mode == zigpy.types.AddrMode.NWK:
if packet.extended_timeout and device is not None:
await self._ezsp.setExtendedTimeout(
remoteEui64=device.ieee, extendedTimeout=True
await self._ezsp.set_extended_timeout(
nwk=device.nwk,
ieee=device.ieee,
extended_timeout=True,
)

if packet.source_route is not None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def form_network():
)

proto.factory_reset = AsyncMock(proto=proto.factory_reset)
proto.set_extended_timeout = AsyncMock(proto=proto.set_extended_timeout)

proto.read_link_keys = MagicMock()
proto.read_link_keys.return_value.__aiter__.return_value = [
Expand Down Expand Up @@ -842,6 +843,19 @@ async def test_send_packet_unicast_source_route(make_app, packet):
)


async def test_send_packet_unicast_extended_timeout(app, ieee, packet):
app.add_device(nwk=packet.dst.address, ieee=ieee)

await _test_send_packet_unicast(
app,
packet.replace(extended_timeout=True),
)

assert app._ezsp._protocol.set_extended_timeout.mock_calls == [
call(nwk=packet.dst.address, ieee=ieee, extended_timeout=True)
]


@patch("bellows.zigbee.application.RETRY_DELAYS", [0.01, 0.01, 0.01])
async def test_send_packet_unicast_retries_success(app, packet):
await _test_send_packet_unicast(
Expand Down
138 changes: 137 additions & 1 deletion tests/test_ezsp_v4.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch

import pytest
import zigpy.state
Expand Down Expand Up @@ -379,3 +379,139 @@ async def test_read_counters(ezsp_f, length: int) -> None:
)

assert counters1 == counters2 == {t.EmberCounterType(i): i for i in range(length)}


async def test_set_extended_timeout_no_entry(ezsp_f) -> None:
# Typical invocation
ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,)
ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,) # No address table entry
ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.SUCCESS, 8)
ezsp_f.replaceAddressTableEntry.return_value = (
t.EmberStatus.SUCCESS,
t.EUI64.convert("ff:ff:ff:ff:ff:ff:ff:ff"),
0xFFFF,
t.Bool.false,
)

with patch("bellows.ezsp.v4.random.randint") as mock_random:
mock_random.return_value = 0
await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.lookupNodeIdByEui64.mock_calls == [
call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.getConfigurationValue.mock_calls == [
call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE)
]
assert mock_random.mock_calls == [call(0, 8 - 1)]
assert ezsp_f.replaceAddressTableEntry.mock_calls == [
call(
addressTableIndex=0,
newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
newId=0x1234,
newExtendedTimeout=True,
)
]

# The address table size is cached
with patch("bellows.ezsp.v4.random.randint") as mock_random:
mock_random.return_value = 1
await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

# Still called only once
assert ezsp_f.getConfigurationValue.mock_calls == [
call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE)
]

assert ezsp_f.replaceAddressTableEntry.mock_calls == [
call(
addressTableIndex=0,
newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
newId=0x1234,
newExtendedTimeout=True,
),
call(
addressTableIndex=1,
newEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
newId=0x1234,
newExtendedTimeout=True,
),
]


async def test_set_extended_timeout_already_set(ezsp_f) -> None:
# No-op, it's already set
ezsp_f.setExtendedTimeout.return_value = ()
ezsp_f.getExtendedTimeout.return_value = (t.Bool.true,)

await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.setExtendedTimeout.mock_calls == []


async def test_set_extended_timeout_already_have_entry(ezsp_f) -> None:
# An address table entry is present
ezsp_f.setExtendedTimeout.return_value = ()
ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,)
ezsp_f.lookupNodeIdByEui64.return_value = (0x1234,)

await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.lookupNodeIdByEui64.mock_calls == [
call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.setExtendedTimeout.mock_calls == [
call(
remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), extendedTimeout=True
)
]


async def test_set_extended_timeout_bad_table_size(ezsp_f) -> None:
ezsp_f.setExtendedTimeout.return_value = ()
ezsp_f.getExtendedTimeout.return_value = (t.Bool.false,)
ezsp_f.lookupNodeIdByEui64.return_value = (0xFFFF,)
ezsp_f.getConfigurationValue.return_value = (t.EmberStatus.ERR_FATAL, 0xFF)

with patch("bellows.ezsp.v4.random.randint") as mock_random:
mock_random.return_value = 0
await ezsp_f.set_extended_timeout(
nwk=0x1234,
ieee=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"),
extended_timeout=True,
)

assert ezsp_f.getExtendedTimeout.mock_calls == [
call(remoteEui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.lookupNodeIdByEui64.mock_calls == [
call(eui64=t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"))
]
assert ezsp_f.getConfigurationValue.mock_calls == [
call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE)
]
Loading