Skip to content

Commit

Permalink
Merge original request rospogrigio#491 from upstream branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Tuen Lee committed Feb 14, 2022
1 parent f776007 commit 58e6d2d
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 35 deletions.
99 changes: 78 additions & 21 deletions custom_components/localtuya/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import asyncio
import logging
from datetime import timedelta
import platform
import subprocess
from anyio import sleep

from sqlalchemy import true
from homeassistant.config_entries import ConfigEntry

from homeassistant.const import (
Expand Down Expand Up @@ -161,7 +166,6 @@ def async_connect(self):
async def _make_connection(self):
"""Subscribe localtuya entity events."""
self.debug("Connecting to %s", self._config_entry[CONF_HOST])

try:
self._interface = await pytuya.connect(
self._config_entry[CONF_HOST],
Expand Down Expand Up @@ -212,6 +216,37 @@ def _new_entity_handler(entity_id):
self._interface = None
self._connect_task = None

if self._interface is not None:
try:
self.debug("Retrieving initial state")
status = await self._interface.status()
if status is None or not status:
raise Exception("Failed to retrieve status")

self._interface.start_heartbeat()
self.status_updated(status)

except Exception: # pylint: disable=broad-except
try:
if self._interface is not None:
self.debug("Initial state update failed, trying refresh command")
await self._interface.refresh()

self.debug("Refresh completed, retrying initial state")
status = await self._interface.status()
if status is None or not status:
raise Exception("Failed to retrieve status")

self._interface.start_heartbeat()
self.status_updated(status)
else:
raise Exception("Interface went away")
except Exception: # pylint: disable=broad-except
self.exception(f"Refresh/initial status update of {self._config_entry[CONF_HOST]} failed")
if self._interface is not None:
await self._interface.close()
self._interface = None

async def _async_refresh(self, _now):
if self._interface is not None:
await self._interface.update_dps()
Expand Down Expand Up @@ -254,7 +289,13 @@ async def set_dps(self, states):
@callback
def status_updated(self, status):
"""Device updated status."""
self._status.update(status)
if status:
self._status.update(status)

signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}"
async_dispatcher_send(self._hass, signal, self._status)
else:
self.debug("Received empty status update (likely for refresh)")
self._dispatch_status()

def _dispatch_status(self):
Expand Down Expand Up @@ -291,7 +332,10 @@ def __init__(self, hass, config_entry):

# Safety check
if not config_entry.get(CONF_IS_GATEWAY):
raise Exception("Device {0} is not a gateway but using TuyaGatewayDevice!", config_entry[CONF_DEVICE_ID])
raise Exception(
"Device {0} is not a gateway but using TuyaGatewayDevice!",
config_entry[CONF_DEVICE_ID],
)

@property
def connected(self):
Expand Down Expand Up @@ -397,15 +441,14 @@ async def _get_sub_device_status(self, cid, is_retry):
self._dispatch_event(GW_EVT_DISCONNECTED, None, cid)

def _dispatch_event(self, event, event_data, cid):
self.debug("Dispatching event %s to sub-device %s with data %s", event, cid, event_data)
self.debug(
"Dispatching event %s to sub-device %s with data %s", event, cid, event_data
)

async_dispatcher_send(
self._hass,
f"localtuya_subdevice_{cid}",
{
"event": event,
"event_data": event_data
}
{"event": event, "event_data": event_data},
)

async def _retry_sub_device_connection(self, _now):
Expand Down Expand Up @@ -468,7 +511,10 @@ def __init__(self, hass, config_entry):

# Safety check
if not config_entry.get(CONF_PARENT_GATEWAY):
raise Exception("Device {0} is not a sub-device but using TuyaSubDevice!", config_entry[CONF_DEVICE_ID])
raise Exception(
"Device {0} is not a sub-device but using TuyaSubDevice!",
config_entry[CONF_DEVICE_ID],
)

# Populate dps list from entities
for entity in config_entry[CONF_ENTITIES]:
Expand Down Expand Up @@ -506,9 +552,9 @@ def _new_entity_handler(entity_id):
self._hass, signal, _new_entity_handler
)

self._async_dispatch_gateway_request(GW_REQ_ADD, {
"dps": self.dps_to_request
})
self._async_dispatch_gateway_request(
GW_REQ_ADD, {"dps": self.dps_to_request}
)

self._is_added = True

Expand All @@ -535,7 +581,9 @@ def _handle_gateway_event(self, data):
self.debug("Invalid event %s from gateway", event)

def _async_dispatch_gateway_request(self, request, content):
self.debug("Dispatching request %s to gateway with content %s", request, content)
self.debug(
"Dispatching request %s to gateway with content %s", request, content
)
asyncio.create_task(self._gateway_request_task(request, content))

async def _gateway_request_task(self, request, content):
Expand All @@ -557,7 +605,10 @@ async def _gateway_request_task(self, request, content):
return

if self._pending_request["retry_count"] >= SUB_DEVICE_DISPATCH_RETRY_MAX:
self.debug("Request %s exceeded maximum retries", self._pending_request["request"])
self.debug(
"Request %s exceeded maximum retries",
self._pending_request["request"],
)
self._pending_request["request"] = None
return

Expand All @@ -577,10 +628,13 @@ async def _gateway_request_task(self, request, content):
async def set_dp(self, state, dp_index):
"""Change value of a DP of the Tuya device."""
if self._is_connected:
self._async_dispatch_gateway_request(GW_REQ_SET_DP, {
"value": state,
"dp_index": dp_index,
})
self._async_dispatch_gateway_request(
GW_REQ_SET_DP,
{
"value": state,
"dp_index": dp_index,
},
)
else:
self.error(
"Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME]
Expand All @@ -589,9 +643,12 @@ async def set_dp(self, state, dp_index):
async def set_dps(self, states):
"""Change value of DPs of the Tuya device."""
if self._is_connected:
self._async_dispatch_gateway_request(GW_REQ_SET_DPS, {
"dps": states,
})
self._async_dispatch_gateway_request(
GW_REQ_SET_DPS,
{
"dps": states,
},
)
else:
self.error(
"Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME]
Expand Down
63 changes: 49 additions & 14 deletions custom_components/localtuya/pytuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
STATUS = "status"
HEARTBEAT = "heartbeat"
UPDATEDPS = "updatedps" # Request refresh of DPS
REFRESH = "refresh"

PROTOCOL_VERSION_BYTES_31 = b"3.1"
PROTOCOL_VERSION_BYTES_33 = b"3.3"
Expand Down Expand Up @@ -104,6 +105,7 @@
STATUS: {"hexByte": COMMAND_DP_QUERY, "command": {"cid": ""}},
SET: {"hexByte": COMMAND_SET, "command": {"cid": "", "t": ""}},
HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}},
REFRESH: {"hexByte": COMMAND_UPDATE_DPS, "command": {"gwId": "", "devId": "", "uid": "", "t": "", "dpId": ""}},
},
# TYPE_0D should never be used with gateways
}
Expand All @@ -119,9 +121,12 @@
SET: {"hexByte": COMMAND_SET, "command": {"devId": "", "uid": "", "t": ""}},
HEARTBEAT: {"hexByte": COMMAND_HEARTBEAT, "command": {}},
UPDATEDPS: {"hexByte": COMMAND_UPDATE_DPS, "command": {"dpId": [18, 19, 20]}},
REFRESH: {"hexByte": COMMAND_UPDATE_DPS, "command": {"gwId": "", "devId": "", "uid": "", "t": "", "dpId": ""}},
},
}

REFRESH_IDS = [4, 5, 6, 18, 19, 20]


class TuyaLoggingAdapter(logging.LoggerAdapter):
"""Adapter that adds device id to all log points."""
Expand Down Expand Up @@ -231,9 +236,10 @@ def _unpad(data):
class MessageDispatcher(ContextualLogger):
"""Buffer and dispatcher for Tuya messages."""

# Heartbeats always respond with sequence number 0, so they can't be waited for like
# other messages. This is a hack to allow waiting for heartbeats.
# Heartbeats and refreshes always respond with sequence number 0, so they can't be waited for like
# other messages. This is a hack to allow waiting for them.
HEARTBEAT_SEQNO = -100
REFRESH_SEQNO = -101

def __init__(self, dev_id, listener):
"""Initialize a new MessageBuffer."""
Expand Down Expand Up @@ -317,10 +323,24 @@ def _dispatch(self, msg):
self.listeners[self.HEARTBEAT_SEQNO] = msg
sem.release()
elif msg.cmd == COMMAND_UPDATE_DPS:
self.debug("Got normal updatedps response")
self.debug("Got normal refresh response")
if self.REFRESH_SEQNO in self.listeners:
sem = self.listeners[self.REFRESH_SEQNO]
self.listeners[self.REFRESH_SEQNO] = msg
if isinstance(sem, asyncio.Semaphore):
sem.release()
elif msg.cmd == COMMAND_STATUS:
self.debug("Got status update")
self.listener(msg)
# If we have an open refresh call then this is for it.
# Some devices send 0x12 and 0x08 in response to a refresh.
# Empty DPS responses here are always for refresh but hey we haven't decoded yet to know
if self.REFRESH_SEQNO in self.listeners and isinstance(self.listeners[self.REFRESH_SEQNO], asyncio.Semaphore):
self.debug("Got status type refresh response")
sem = self.listeners[self.REFRESH_SEQNO]
self.listeners[self.REFRESH_SEQNO] = msg
sem.release()
else:
self.debug("Got status update")
self.listener(msg)
else:
self.debug(
"Got message type %d for unknown listener %d: %s",
Expand Down Expand Up @@ -400,7 +420,10 @@ def _status_update(self, msg):

def connection_made(self, transport):
"""Did connect to the device."""
self.transport = transport
self.on_connected.set_result(True)

def start_heartbeat(self):
async def heartbeat_loop():
"""Continuously send heart beat updates."""
self.debug("Started heartbeat loop")
Expand All @@ -422,9 +445,9 @@ async def heartbeat_loop():
self.transport = None
transport.close()

self.transport = transport
self.on_connected.set_result(True)
self.heartbeater = self.loop.create_task(heartbeat_loop())
self.transport = transport
self.on_connected.set_result(True)
self.heartbeater = self.loop.create_task(heartbeat_loop())

def data_received(self, data):
"""Received data from device."""
Expand Down Expand Up @@ -468,12 +491,13 @@ async def exchange(self, command, dps=None, cid=None):
payload = self._generate_payload(command, dps, cid)
dev_type = self.dev_type

# Wait for special sequence number if heartbeat
seqno = (
MessageDispatcher.HEARTBEAT_SEQNO
if command == HEARTBEAT
else (self.seqno - 1)
)
# Wait for special sequence number if heartbeat or refresh
seqno = (self.seqno - 1)

if command == HEARTBEAT:
seqno = MessageDispatcher.HEARTBEAT_SEQNO
elif command == REFRESH:
seqno = MessageDispatcher.REFRESH_SEQNO

self.transport.write(payload)
msg = await self.dispatcher.wait_for(seqno)
Expand Down Expand Up @@ -517,6 +541,15 @@ async def status(self, cid=None):
async def heartbeat(self):
"""Send a heartbeat message."""
return await self.exchange(HEARTBEAT)

async def refresh(self):
"""Send a refresh message (3.3 only)."""
if self.version == 3.3:
self.dev_type = "type_0a"
self.debug("refresh switching to dev_type %s", self.dev_type)
return await self.exchange(REFRESH)
else:
return True

async def update_dps(self, dps=None):
"""
Expand Down Expand Up @@ -720,6 +753,8 @@ def _generate_payload(self, command, data=None, cid=None):
json_data["cid"] = cid # for Zigbee gateways, cid specifies the sub-device
if "t" in json_data:
json_data["t"] = str(int(time.time()))
if "dpId" in json_data:
json_data["dpId"] = REFRESH_IDS

if data is not None:
if "dpId" in json_data:
Expand Down

0 comments on commit 58e6d2d

Please sign in to comment.