Skip to content

Commit

Permalink
Support adding and removing entities to devices
Browse files Browse the repository at this point in the history
New entities can be added via + button under Integrations (like when
adding a new device) and entities can be removed from Options.
  • Loading branch information
postlund committed Nov 25, 2020
1 parent 095f649 commit 7e2f55a
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 46 deletions.
34 changes: 23 additions & 11 deletions custom_components/localtuya/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@

from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import (
CONF_ID,
CONF_DEVICE_ID,
CONF_ENTITIES,
CONF_HOST,
Expand All @@ -68,8 +69,9 @@
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.reload import async_integration_yaml_config
import homeassistant.helpers.entity_registry as er

from .common import TuyaDevice
from .common import TuyaDevice, async_config_entry_by_device_id
from .config_flow import config_schema
from .const import CONF_PRODUCT_KEY, DATA_DISCOVERY, DOMAIN, TUYA_DEVICE
from .discovery import TuyaDiscovery
Expand Down Expand Up @@ -117,14 +119,6 @@ async def _handle_reload(service):

await asyncio.gather(*reload_tasks)

def _entry_by_device_id(device_id):
"""Look up config entry by device id."""
current_entries = hass.config_entries.async_entries(DOMAIN)
for entry in current_entries:
if entry.data[CONF_DEVICE_ID] == device_id:
return entry
return None

def _device_discovered(device):
"""Update address of device if it has changed."""
device_ip = device["ip"]
Expand All @@ -133,7 +127,7 @@ def _device_discovered(device):

# If device is not in cache, check if a config entry exists
if device_id not in device_cache:
entry = _entry_by_device_id(device_id)
entry = async_config_entry_by_device_id(hass, device_id)
if entry:
# Save address from config entry in cache to trigger
# potential update below
Expand All @@ -142,7 +136,7 @@ def _device_discovered(device):
if device_id not in device_cache:
return

entry = _entry_by_device_id(device_id)
entry = async_config_entry_by_device_id(hass, device_id)
if entry is None:
return

Expand Down Expand Up @@ -211,6 +205,8 @@ async def setup_entities():
)
device.connect()

await async_remove_orphan_entities(hass, entry)

hass.async_create_task(setup_entities())

return True
Expand Down Expand Up @@ -240,3 +236,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
async def update_listener(hass, config_entry):
"""Update listener."""
await hass.config_entries.async_reload(config_entry.entry_id)


async def async_remove_orphan_entities(hass, entry):
"""Remove entities associated with config entry that has been removed."""
ent_reg = await er.async_get_registry(hass)
entities = {
int(ent.unique_id.split("_")[-1]): ent.entity_id
for ent in er.async_entries_for_config_entry(ent_reg, entry.entry_id)
}

for entity in entry.data[CONF_ENTITIES]:
if entity[CONF_ID] in entities:
del entities[entity[CONF_ID]]

for entity_id in entities.values():
ent_reg.async_remove(entity_id)
10 changes: 10 additions & 0 deletions custom_components/localtuya/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ def get_entity_config(config_entry, dp_id):
raise Exception(f"missing entity config for id {dp_id}")


@callback
def async_config_entry_by_device_id(hass, device_id):
"""Look up config entry by device id."""
current_entries = hass.config_entries.async_entries(DOMAIN)
for entry in current_entries:
if entry.data[CONF_DEVICE_ID] == device_id:
return entry
return None


class TuyaDevice(pytuya.TuyaListener, pytuya.ContextualLogger):
"""Cache wrapper for pytuya.TuyaInterface."""

Expand Down
105 changes: 74 additions & 31 deletions custom_components/localtuya/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from homeassistant.core import callback

from . import pytuya
from .common import pytuya, async_config_entry_by_device_id
from .const import CONF_DPS_STRINGS # pylint: disable=unused-import
from .const import (
CONF_LOCAL_KEY,
Expand Down Expand Up @@ -46,14 +46,6 @@
}
)

OPTIONS_SCHEMA = vol.Schema(
{
vol.Required(CONF_FRIENDLY_NAME): str,
vol.Required(CONF_HOST): str,
vol.Required(CONF_LOCAL_KEY): str,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]),
}
)

DEVICE_SCHEMA = vol.Schema(
{
Expand All @@ -70,11 +62,33 @@
)


def user_schema(devices):
def user_schema(devices, entries):
"""Create schema for user step."""
devices = [f"{ip} ({dev['gwId']})" for ip, dev in devices.items()]
devices = {dev_id: dev["ip"] for dev_id, dev in devices.items()}
devices.update(
{ent.data[CONF_DEVICE_ID]: ent.data[CONF_FRIENDLY_NAME] for ent in entries}
)
device_list = [f"{key} ({value})" for key, value in devices.items()]
return vol.Schema(
{vol.Required(DISCOVERED_DEVICE): vol.In(device_list + [CUSTOM_DEVICE])}
)


def options_schema(entities):
"""Create schema for options."""
entity_names = [
f"{entity[CONF_ID]} {entity[CONF_FRIENDLY_NAME]}" for entity in entities
]
return vol.Schema(
{vol.Required(DISCOVERED_DEVICE): vol.In(devices + [CUSTOM_DEVICE])}
{
vol.Required(CONF_FRIENDLY_NAME): str,
vol.Required(CONF_HOST): str,
vol.Required(CONF_LOCAL_KEY): str,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]),
vol.Required(
CONF_ENTITIES, description={"suggested_value": entity_names}
): cv.multi_select(entity_names),
}
)


Expand Down Expand Up @@ -213,8 +227,7 @@ async def async_step_user(self, user_input=None):
errors = {}
if user_input is not None:
if user_input[DISCOVERED_DEVICE] != CUSTOM_DEVICE:
device = user_input[DISCOVERED_DEVICE].split(" ")[0]
self.selected_device = self.devices[device]
self.selected_device = user_input[DISCOVERED_DEVICE].split(" ")[0]
return await self.async_step_basic_info()

# Use cache if available or fallback to manual discovery
Expand All @@ -241,22 +254,23 @@ async def async_step_user(self, user_input=None):
}

return self.async_show_form(
step_id="user", errors=errors, data_schema=user_schema(self.devices)
step_id="user",
errors=errors,
data_schema=user_schema(self.devices, self._async_current_entries()),
)

async def async_step_basic_info(self, user_input=None):
"""Handle input of basic info."""
errors = {}
if user_input is not None:
await self.async_set_unique_id(user_input[CONF_DEVICE_ID])
self._abort_if_unique_id_configured()

try:
self.basic_info = user_input
if self.selected_device is not None:
self.basic_info[CONF_PRODUCT_KEY] = self.selected_device[
"productKey"
]
self.basic_info[CONF_PRODUCT_KEY] = self.devices[
self.selected_device
]["productKey"]
self.dps_strings = await validate_input(self.hass, user_input)
return await self.async_step_pick_entity_type()
except CannotConnect:
Expand All @@ -269,12 +283,23 @@ async def async_step_basic_info(self, user_input=None):
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"

# If selected device exists as a config entry, load config from it
if self.selected_device in self._async_current_ids():
entry = async_config_entry_by_device_id(self.hass, self.selected_device)
await self.async_set_unique_id(entry.data[CONF_DEVICE_ID])
self.basic_info = entry.data.copy()
self.dps_strings = self.basic_info.pop(CONF_DPS_STRINGS).copy()
self.entities = self.basic_info.pop(CONF_ENTITIES).copy()
return await self.async_step_pick_entity_type()

# Insert default values from discovery if present
defaults = {}
defaults.update(user_input or {})
if self.selected_device is not None:
defaults[CONF_HOST] = self.selected_device.get("ip")
defaults[CONF_DEVICE_ID] = self.selected_device.get("gwId")
defaults[CONF_PROTOCOL_VERSION] = self.selected_device.get("version")
device = self.devices[self.selected_device]
defaults[CONF_HOST] = device.get("ip")
defaults[CONF_DEVICE_ID] = device.get("gwId")
defaults[CONF_PROTOCOL_VERSION] = device.get("version")

return self.async_show_form(
step_id="basic_info",
Expand All @@ -291,6 +316,10 @@ async def async_step_pick_entity_type(self, user_input=None):
CONF_DPS_STRINGS: self.dps_strings,
CONF_ENTITIES: self.entities,
}
entry = async_config_entry_by_device_id(self.hass, self.unique_id)
if entry:
self.hass.config_entries.async_update_entry(entry, data=config)
return self.async_abort(reason="device_updated")
return self.async_create_entry(
title=config[CONF_FRIENDLY_NAME], data=config
)
Expand All @@ -313,7 +342,8 @@ async def async_step_add_entity(self, user_input=None):
errors = {}
if user_input is not None:
already_configured = any(
switch[CONF_ID] == user_input[CONF_ID] for switch in self.entities
switch[CONF_ID] == int(user_input[CONF_ID].split(" ")[0])
for switch in self.entities
)
if not already_configured:
user_input[CONF_PLATFORM] = self.platform
Expand Down Expand Up @@ -352,21 +382,34 @@ async def async_step_init(self, user_input=None):
"""Manage basic options."""
device_id = self.config_entry.data[CONF_DEVICE_ID]
if user_input is not None:
self.data = {
CONF_DEVICE_ID: device_id,
CONF_DPS_STRINGS: self.dps_strings,
CONF_ENTITIES: [],
}
self.data.update(user_input)
return await self.async_step_entity()
self.data = user_input.copy()
self.data.update(
{
CONF_DEVICE_ID: device_id,
CONF_DPS_STRINGS: self.dps_strings,
CONF_ENTITIES: [],
}
)
if len(user_input[CONF_ENTITIES]) > 0:
entity_ids = [
int(entity.split(" ")[0]) for entity in user_input[CONF_ENTITIES]
]
self.entities = [
entity
for entity in self.config_entry.data[CONF_ENTITIES]
if entity[CONF_ID] in entity_ids
]
return await self.async_step_entity()

# Not supported for YAML imports
if self.config_entry.source == config_entries.SOURCE_IMPORT:
return await self.async_step_yaml_import()

return self.async_show_form(
step_id="init",
data_schema=schema_defaults(OPTIONS_SCHEMA, **self.config_entry.data),
data_schema=schema_defaults(
options_schema(self.entities), **self.config_entry.data
),
description_placeholders={"device_id": device_id},
)

Expand Down
4 changes: 2 additions & 2 deletions custom_components/localtuya/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def start(self):
lambda: self, local_addr=("0.0.0.0", 6667)
)

self.listeners = await asyncio.gather(listener, encrypted_listener)
self._listeners = await asyncio.gather(listener, encrypted_listener)
_LOGGER.debug("Listening to broadcasts on UDP port 6666 and 6667")

def close(self):
Expand All @@ -72,7 +72,7 @@ def datagram_received(self, data, addr):
def device_found(self, device):
"""Discover a new device."""
if device.get("ip") not in self.devices:
self.devices[device.get("ip")] = device
self.devices[device.get("gwId")] = device
_LOGGER.debug("Discovered device: %s", device)

if self._callback:
Expand Down
6 changes: 4 additions & 2 deletions custom_components/localtuya/translations/en.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"config": {
"abort": {
"already_configured": "Device has already been configured."
"already_configured": "Device has already been configured.",
"device_updated": "Device configuration has been updated!"
},
"error": {
"cannot_connect": "Cannot connect to device. Verify that address is correct and try again.",
Expand Down Expand Up @@ -87,7 +88,8 @@
"friendly_name": "Friendly Name",
"host": "Host",
"local_key": "Local key",
"protocol_version": "Protocol Version"
"protocol_version": "Protocol Version",
"entities": "Entities (uncheck an entity to remove it)"
}
},
"entity": {
Expand Down

0 comments on commit 7e2f55a

Please sign in to comment.