From 758803d0c7db413bab910a2563f12ebd171c8d90 Mon Sep 17 00:00:00 2001 From: Conrad Kreyling Date: Sat, 30 Dec 2023 15:04:58 -0500 Subject: [PATCH] Reuse ZDO Initializers to create Endpoint objects on EZSP device (#599) * Ensure device endpoints sync with register_endpoints This way if we need a reference to one we can grab it in a way we expect * Make the linter and tests happy * Better names, better comments. * Merge complex endpoint functionality into simple endpoint * Refactor unit tests to allow slightly better mocking of startup * Simplify endpoint creation --------- Co-authored-by: puddly <32534428+puddly@users.noreply.github.com> --- bellows/zigbee/application.py | 28 ++++++++--- bellows/zigbee/device.py | 41 +++++++++++++--- tests/test_application.py | 90 ++++++++++++++++++++--------------- 3 files changed, 108 insertions(+), 51 deletions(-) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index d486ff05..6bb3b919 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -80,6 +80,7 @@ class ControllerApplication(zigpy.application.ControllerApplication): def __init__(self, config: dict): super().__init__(config) self._ctrl_event = asyncio.Event() + self._created_device_endpoints: list[zdo_t.SimpleDescriptor] = [] self._ezsp = None self._multicast = None self._mfg_id_task: asyncio.Task | None = None @@ -116,9 +117,12 @@ async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None: descriptor.input_clusters, descriptor.output_clusters, ) + if status != t.EmberStatus.SUCCESS: raise StackAlreadyRunning() + self._created_device_endpoints.append(descriptor) + async def cleanup_tc_link_key(self, ieee: t.EUI64) -> None: """Remove tc link_key for the given device.""" (index,) = await self._ezsp.findKeyTableEntry(ieee, True) @@ -150,6 +154,8 @@ async def connect(self) -> None: raise self._ezsp = ezsp + + self._created_device_endpoints.clear() await self.register_endpoints() async def _ensure_network_running(self) -> bool: @@ -198,10 +204,15 @@ async def start_network(self): ezsp.add_callback(self.ezsp_callback_handler) self.controller_event.set() + group_membership = {} + try: db_device = self.get_device(ieee=self.state.node_info.ieee) except KeyError: - db_device = None + pass + else: + if 1 in db_device.endpoints: + group_membership = db_device.endpoints[1].member_of ezsp_device = zigpy.device.Device( application=self, @@ -210,15 +221,18 @@ async def start_network(self): ) self.devices[self.state.node_info.ieee] = ezsp_device - # The coordinator device does not respond to attribute reads - ezsp_device.endpoints[1] = EZSPEndpoint(ezsp_device, 1) - ezsp_device.model = ezsp_device.endpoints[1].model - ezsp_device.manufacturer = ezsp_device.endpoints[1].manufacturer + # The coordinator device does not respond to attribute reads so we have to + # divine the internal NCP state. + for zdo_desc in self._created_device_endpoints: + ep = EZSPEndpoint(ezsp_device, zdo_desc.endpoint, zdo_desc) + ezsp_device.endpoints[zdo_desc.endpoint] = ep + ezsp_device.model = ep.model + ezsp_device.manufacturer = ep.manufacturer + await ezsp_device.schedule_initialize() # Group membership is stored in the database for EZSP coordinators - if db_device is not None and 1 in db_device.endpoints: - ezsp_device.endpoints[1].member_of.update(db_device.endpoints[1].member_of) + ezsp_device.endpoints[1].member_of.update(group_membership) self._multicast = bellows.multicast.Multicast(ezsp) await self._multicast.startup(ezsp_device) diff --git a/bellows/zigbee/device.py b/bellows/zigbee/device.py index fcfff160..9e4a6e7b 100644 --- a/bellows/zigbee/device.py +++ b/bellows/zigbee/device.py @@ -1,22 +1,51 @@ from __future__ import annotations import logging -import typing import zigpy.device import zigpy.endpoint -import zigpy.util -import zigpy.zdo +import zigpy.profiles.zgp +import zigpy.profiles.zha +import zigpy.profiles.zll +import zigpy.zdo.types as zdo_t import bellows.types as t -if typing.TYPE_CHECKING: - import zigpy.application # pragma: no cover - LOGGER = logging.getLogger(__name__) +PROFILE_TO_DEVICE_TYPE = { + zigpy.profiles.zha.PROFILE_ID: zigpy.profiles.zha.DeviceType, + zigpy.profiles.zll.PROFILE_ID: zigpy.profiles.zll.DeviceType, + zigpy.profiles.zgp.PROFILE_ID: zigpy.profiles.zgp.DeviceType, +} + class EZSPEndpoint(zigpy.endpoint.Endpoint): + def __init__( + self, + device: zigpy.device.Device, + endpoint_id: int, + descriptor: zdo_t.SimpleDescriptor, + ) -> None: + super().__init__(device, endpoint_id) + + self.profile_id = descriptor.profile + + if self.profile_id in PROFILE_TO_DEVICE_TYPE: + self.device_type = PROFILE_TO_DEVICE_TYPE[self.profile_id]( + descriptor.device_type + ) + else: + self.device_type = descriptor.device_type + + for cluster in descriptor.input_clusters: + self.add_input_cluster(cluster) + + for cluster in descriptor.output_clusters: + self.add_output_cluster(cluster) + + self.status = zigpy.endpoint.Status.ZDO_INIT + @property def manufacturer(self) -> str: """Manufacturer.""" diff --git a/tests/test_application.py b/tests/test_application.py index aceeabb3..c992b8e3 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import logging from unittest.mock import AsyncMock, MagicMock, PropertyMock, call, patch, sentinel @@ -114,7 +115,6 @@ def aps(): return f -@patch("zigpy.device.Device._initialize", new=AsyncMock()) def _create_app_for_startup( app, nwk_type, @@ -206,6 +206,14 @@ async def mock_leave(*args, **kwargs): ), ] ) + ezsp_mock.getMulticastTableEntry = AsyncMock( + return_value=[ + t.EmberStatus.SUCCESS, + t.EmberMulticastTableEntry(multicastId=0x0000, endpoint=0, networkIndex=0), + ] + ) + ezsp_mock.setMulticastTableEntry = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) + app.permit = AsyncMock() def form_network(): @@ -220,10 +228,11 @@ def form_network(): return ezsp_mock -async def _test_startup( +@contextlib.contextmanager +def mock_for_startup( app, - nwk_type, ieee, + nwk_type=t.EmberNodeType.COORDINATOR, auto_form=False, init=0, ezsp_version=4, @@ -234,10 +243,25 @@ async def _test_startup( app, nwk_type, ieee, auto_form, init, ezsp_version, board_info, network_state ) - p1 = patch("bellows.ezsp.EZSP", return_value=ezsp_mock) - p2 = patch.object(bellows.multicast.Multicast, "startup") + with patch("bellows.ezsp.EZSP", return_value=ezsp_mock), patch( + "zigpy.device.Device._initialize", new=AsyncMock() + ): + yield ezsp_mock + - with p1, p2 as multicast_mock: +async def _test_startup( + app, + nwk_type, + ieee, + auto_form=False, + init=0, + ezsp_version=4, + board_info=True, + network_state=t.EmberNetworkStatus.JOINED_NETWORK, +): + with mock_for_startup( + app, ieee, nwk_type, auto_form, init, ezsp_version, board_info, network_state + ) as ezsp_mock: await app.startup(auto_form=auto_form) if ezsp_version > 6: @@ -247,7 +271,6 @@ async def _test_startup( assert ezsp_mock.write_config.call_count == 1 assert ezsp_mock.addEndpoint.call_count >= 2 - assert multicast_mock.await_count == 1 async def test_startup(app, ieee): @@ -1166,7 +1189,7 @@ async def test_shutdown(app): @pytest.fixture def coordinator(app, ieee): dev = zigpy.device.Device(app, ieee, 0x0000) - dev.endpoints[1] = bellows.zigbee.device.EZSPEndpoint(dev, 1) + dev.endpoints[1] = bellows.zigbee.device.EZSPEndpoint(dev, 1, MagicMock()) dev.model = dev.endpoints[1].model dev.manufacturer = dev.endpoints[1].manufacturer @@ -1505,42 +1528,32 @@ async def test_ensure_network_running_not_joined_success(app): async def test_startup_coordinator_existing_groups_joined(app, ieee): """Coordinator joins groups loaded from the database.""" + with mock_for_startup(app, ieee) as ezsp_mock: + await app.connect() - app._ensure_network_running = AsyncMock() - app._ezsp.update_policies = AsyncMock() - app.load_network_info = AsyncMock() - app.state.node_info.ieee = ieee - - db_device = app.add_device(ieee, 0x0000) - db_ep = db_device.add_endpoint(1) - - app.groups.add_group(0x1234, "Group Name", suppress_event=True) - app.groups[0x1234].add_member(db_ep, suppress_event=True) + db_device = app.add_device(ieee, 0x0000) + db_ep = db_device.add_endpoint(1) - p1 = patch.object(bellows.multicast.Multicast, "_initialize") - p2 = patch.object(bellows.multicast.Multicast, "subscribe") + app.groups.add_group(0x1234, "Group Name", suppress_event=True) + app.groups[0x1234].add_member(db_ep, suppress_event=True) - with p1 as p1, p2 as p2: await app.start_network() - p2.assert_called_once_with(0x1234) + assert ezsp_mock.setMulticastTableEntry.mock_calls == [ + call( + 0, + t.EmberMulticastTableEntry(multicastId=0x1234, endpoint=1, networkIndex=0), + ) + ] async def test_startup_new_coordinator_no_groups_joined(app, ieee): """Coordinator freshy added to the database has no groups to join.""" - - app._ensure_network_running = AsyncMock() - app._ezsp.update_policies = AsyncMock() - app.load_network_info = AsyncMock() - app.state.node_info.ieee = ieee - - p1 = patch.object(bellows.multicast.Multicast, "_initialize") - p2 = patch.object(bellows.multicast.Multicast, "subscribe") - - with p1 as p1, p2 as p2: + with mock_for_startup(app, ieee) as ezsp_mock: + await app.connect() await app.start_network() - p2.assert_not_called() + assert ezsp_mock.setMulticastTableEntry.mock_calls == [] @pytest.mark.parametrize( @@ -1628,22 +1641,23 @@ async def test_connect_failure( assert len(ezsp_mock.close.mock_calls) == 1 -async def test_repair_tclk_partner_ieee(app: ControllerApplication) -> None: +async def test_repair_tclk_partner_ieee( + app: ControllerApplication, ieee: zigpy_t.EUI64 +) -> None: """Test that EZSP is reset after repairing TCLK.""" - app._ensure_network_running = AsyncMock() app._reset = AsyncMock() - app.load_network_info = AsyncMock() - with patch( + with mock_for_startup(app, ieee), patch( "bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee", AsyncMock(return_value=False), ): + await app.connect() await app.start_network() assert len(app._reset.mock_calls) == 0 app._reset.reset_mock() - with patch( + with mock_for_startup(app, ieee), patch( "bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee", AsyncMock(return_value=True), ):