From 58e6d2dd7683a237f9a46a5880e9a38eda7e3079 Mon Sep 17 00:00:00 2001 From: Tuen Lee Date: Mon, 14 Feb 2022 11:23:30 +0100 Subject: [PATCH] Merge original request #491 from upstream branch --- custom_components/localtuya/common.py | 99 +++++++++++++++---- .../localtuya/pytuya/__init__.py | 63 +++++++++--- 2 files changed, 127 insertions(+), 35 deletions(-) diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index 6b0506d19..d5c0d8c08 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -2,6 +2,11 @@ import asyncio import logging from datetime import timedelta +import platform +import subprocess +from anyio import sleep + +from sqlalchemy import true from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( @@ -161,7 +166,6 @@ def async_connect(self): async def _make_connection(self): """Subscribe localtuya entity events.""" self.debug("Connecting to %s", self._config_entry[CONF_HOST]) - try: self._interface = await pytuya.connect( self._config_entry[CONF_HOST], @@ -212,6 +216,37 @@ def _new_entity_handler(entity_id): self._interface = None self._connect_task = None + if self._interface is not None: + try: + self.debug("Retrieving initial state") + status = await self._interface.status() + if status is None or not status: + raise Exception("Failed to retrieve status") + + self._interface.start_heartbeat() + self.status_updated(status) + + except Exception: # pylint: disable=broad-except + try: + if self._interface is not None: + self.debug("Initial state update failed, trying refresh command") + await self._interface.refresh() + + self.debug("Refresh completed, retrying initial state") + status = await self._interface.status() + if status is None or not status: + raise Exception("Failed to retrieve status") + + self._interface.start_heartbeat() + self.status_updated(status) + else: + raise Exception("Interface went away") + except Exception: # pylint: disable=broad-except + self.exception(f"Refresh/initial status update of {self._config_entry[CONF_HOST]} failed") + if self._interface is not None: + await self._interface.close() + self._interface = None + async def _async_refresh(self, _now): if self._interface is not None: await self._interface.update_dps() @@ -254,7 +289,13 @@ async def set_dps(self, states): @callback def status_updated(self, status): """Device updated status.""" - self._status.update(status) + if status: + self._status.update(status) + + signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}" + async_dispatcher_send(self._hass, signal, self._status) + else: + self.debug("Received empty status update (likely for refresh)") self._dispatch_status() def _dispatch_status(self): @@ -291,7 +332,10 @@ def __init__(self, hass, config_entry): # Safety check if not config_entry.get(CONF_IS_GATEWAY): - raise Exception("Device {0} is not a gateway but using TuyaGatewayDevice!", config_entry[CONF_DEVICE_ID]) + raise Exception( + "Device {0} is not a gateway but using TuyaGatewayDevice!", + config_entry[CONF_DEVICE_ID], + ) @property def connected(self): @@ -397,15 +441,14 @@ async def _get_sub_device_status(self, cid, is_retry): self._dispatch_event(GW_EVT_DISCONNECTED, None, cid) def _dispatch_event(self, event, event_data, cid): - self.debug("Dispatching event %s to sub-device %s with data %s", event, cid, event_data) + self.debug( + "Dispatching event %s to sub-device %s with data %s", event, cid, event_data + ) async_dispatcher_send( self._hass, f"localtuya_subdevice_{cid}", - { - "event": event, - "event_data": event_data - } + {"event": event, "event_data": event_data}, ) async def _retry_sub_device_connection(self, _now): @@ -468,7 +511,10 @@ def __init__(self, hass, config_entry): # Safety check if not config_entry.get(CONF_PARENT_GATEWAY): - raise Exception("Device {0} is not a sub-device but using TuyaSubDevice!", config_entry[CONF_DEVICE_ID]) + raise Exception( + "Device {0} is not a sub-device but using TuyaSubDevice!", + config_entry[CONF_DEVICE_ID], + ) # Populate dps list from entities for entity in config_entry[CONF_ENTITIES]: @@ -506,9 +552,9 @@ def _new_entity_handler(entity_id): self._hass, signal, _new_entity_handler ) - self._async_dispatch_gateway_request(GW_REQ_ADD, { - "dps": self.dps_to_request - }) + self._async_dispatch_gateway_request( + GW_REQ_ADD, {"dps": self.dps_to_request} + ) self._is_added = True @@ -535,7 +581,9 @@ def _handle_gateway_event(self, data): self.debug("Invalid event %s from gateway", event) def _async_dispatch_gateway_request(self, request, content): - self.debug("Dispatching request %s to gateway with content %s", request, content) + self.debug( + "Dispatching request %s to gateway with content %s", request, content + ) asyncio.create_task(self._gateway_request_task(request, content)) async def _gateway_request_task(self, request, content): @@ -557,7 +605,10 @@ async def _gateway_request_task(self, request, content): return if self._pending_request["retry_count"] >= SUB_DEVICE_DISPATCH_RETRY_MAX: - self.debug("Request %s exceeded maximum retries", self._pending_request["request"]) + self.debug( + "Request %s exceeded maximum retries", + self._pending_request["request"], + ) self._pending_request["request"] = None return @@ -577,10 +628,13 @@ async def _gateway_request_task(self, request, content): async def set_dp(self, state, dp_index): """Change value of a DP of the Tuya device.""" if self._is_connected: - self._async_dispatch_gateway_request(GW_REQ_SET_DP, { - "value": state, - "dp_index": dp_index, - }) + self._async_dispatch_gateway_request( + GW_REQ_SET_DP, + { + "value": state, + "dp_index": dp_index, + }, + ) else: self.error( "Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME] @@ -589,9 +643,12 @@ async def set_dp(self, state, dp_index): async def set_dps(self, states): """Change value of DPs of the Tuya device.""" if self._is_connected: - self._async_dispatch_gateway_request(GW_REQ_SET_DPS, { - "dps": states, - }) + self._async_dispatch_gateway_request( + GW_REQ_SET_DPS, + { + "dps": states, + }, + ) else: self.error( "Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME] diff --git a/custom_components/localtuya/pytuya/__init__.py b/custom_components/localtuya/pytuya/__init__.py index 7290dc6ee..3a8cc2561 100644 --- a/custom_components/localtuya/pytuya/__init__.py +++ b/custom_components/localtuya/pytuya/__init__.py @@ -63,6 +63,7 @@ STATUS = "status" HEARTBEAT = "heartbeat" UPDATEDPS = "updatedps" # Request refresh of DPS +REFRESH = "refresh" PROTOCOL_VERSION_BYTES_31 = b"3.1" PROTOCOL_VERSION_BYTES_33 = b"3.3" @@ -104,6 +105,7 @@ STATUS: {"hexByte": COMMAND_DP_QUERY, "command": {"cid": ""}}, SET: {"hexByte": COMMAND_SET, "command": {"cid": "", "t": ""}}, HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}}, + REFRESH: {"hexByte": COMMAND_UPDATE_DPS, "command": {"gwId": "", "devId": "", "uid": "", "t": "", "dpId": ""}}, }, # TYPE_0D should never be used with gateways } @@ -119,9 +121,12 @@ SET: {"hexByte": COMMAND_SET, "command": {"devId": "", "uid": "", "t": ""}}, HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}}, UPDATEDPS: {"hexByte": COMMAND_UPDATE_DPS, "command": {"dpId": [18, 19, 20]}}, + REFRESH: {"hexByte": COMMAND_UPDATE_DPS, "command": {"gwId": "", "devId": "", "uid": "", "t": "", "dpId": ""}}, }, } +REFRESH_IDS = [4, 5, 6, 18, 19, 20] + class TuyaLoggingAdapter(logging.LoggerAdapter): """Adapter that adds device id to all log points.""" @@ -231,9 +236,10 @@ def _unpad(data): class MessageDispatcher(ContextualLogger): """Buffer and dispatcher for Tuya messages.""" - # Heartbeats always respond with sequence number 0, so they can't be waited for like - # other messages. This is a hack to allow waiting for heartbeats. + # Heartbeats and refreshes always respond with sequence number 0, so they can't be waited for like + # other messages. This is a hack to allow waiting for them. HEARTBEAT_SEQNO = -100 + REFRESH_SEQNO = -101 def __init__(self, dev_id, listener): """Initialize a new MessageBuffer.""" @@ -317,10 +323,24 @@ def _dispatch(self, msg): self.listeners[self.HEARTBEAT_SEQNO] = msg sem.release() elif msg.cmd == COMMAND_UPDATE_DPS: - self.debug("Got normal updatedps response") + self.debug("Got normal refresh response") + if self.REFRESH_SEQNO in self.listeners: + sem = self.listeners[self.REFRESH_SEQNO] + self.listeners[self.REFRESH_SEQNO] = msg + if isinstance(sem, asyncio.Semaphore): + sem.release() elif msg.cmd == COMMAND_STATUS: - self.debug("Got status update") - self.listener(msg) + # If we have an open refresh call then this is for it. + # Some devices send 0x12 and 0x08 in response to a refresh. + # Empty DPS responses here are always for refresh but hey we haven't decoded yet to know + if self.REFRESH_SEQNO in self.listeners and isinstance(self.listeners[self.REFRESH_SEQNO], asyncio.Semaphore): + self.debug("Got status type refresh response") + sem = self.listeners[self.REFRESH_SEQNO] + self.listeners[self.REFRESH_SEQNO] = msg + sem.release() + else: + self.debug("Got status update") + self.listener(msg) else: self.debug( "Got message type %d for unknown listener %d: %s", @@ -400,7 +420,10 @@ def _status_update(self, msg): def connection_made(self, transport): """Did connect to the device.""" + self.transport = transport + self.on_connected.set_result(True) + def start_heartbeat(self): async def heartbeat_loop(): """Continuously send heart beat updates.""" self.debug("Started heartbeat loop") @@ -422,9 +445,9 @@ async def heartbeat_loop(): self.transport = None transport.close() - self.transport = transport - self.on_connected.set_result(True) - self.heartbeater = self.loop.create_task(heartbeat_loop()) + self.transport = transport + self.on_connected.set_result(True) + self.heartbeater = self.loop.create_task(heartbeat_loop()) def data_received(self, data): """Received data from device.""" @@ -468,12 +491,13 @@ async def exchange(self, command, dps=None, cid=None): payload = self._generate_payload(command, dps, cid) dev_type = self.dev_type - # Wait for special sequence number if heartbeat - seqno = ( - MessageDispatcher.HEARTBEAT_SEQNO - if command == HEARTBEAT - else (self.seqno - 1) - ) + # Wait for special sequence number if heartbeat or refresh + seqno = (self.seqno - 1) + + if command == HEARTBEAT: + seqno = MessageDispatcher.HEARTBEAT_SEQNO + elif command == REFRESH: + seqno = MessageDispatcher.REFRESH_SEQNO self.transport.write(payload) msg = await self.dispatcher.wait_for(seqno) @@ -517,6 +541,15 @@ async def status(self, cid=None): async def heartbeat(self): """Send a heartbeat message.""" return await self.exchange(HEARTBEAT) + + async def refresh(self): + """Send a refresh message (3.3 only).""" + if self.version == 3.3: + self.dev_type = "type_0a" + self.debug("refresh switching to dev_type %s", self.dev_type) + return await self.exchange(REFRESH) + else: + return True async def update_dps(self, dps=None): """ @@ -720,6 +753,8 @@ def _generate_payload(self, command, data=None, cid=None): json_data["cid"] = cid # for Zigbee gateways, cid specifies the sub-device if "t" in json_data: json_data["t"] = str(int(time.time())) + if "dpId" in json_data: + json_data["dpId"] = REFRESH_IDS if data is not None: if "dpId" in json_data: