From 7e2f55a5c55b6dca9a8940faa670efea13d95d45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pierre=20St=C3=A5hl?= Date: Fri, 6 Nov 2020 08:01:03 +0100 Subject: [PATCH] Support adding and removing entities to devices New entities can be added via + button under Integrations (like when adding a new device) and entities can be removed from Options. --- custom_components/localtuya/__init__.py | 34 ++++-- custom_components/localtuya/common.py | 10 ++ custom_components/localtuya/config_flow.py | 105 ++++++++++++------ custom_components/localtuya/discovery.py | 4 +- .../localtuya/translations/en.json | 6 +- 5 files changed, 113 insertions(+), 46 deletions(-) diff --git a/custom_components/localtuya/__init__.py b/custom_components/localtuya/__init__.py index b3b2ae6e9..4f8874d79 100644 --- a/custom_components/localtuya/__init__.py +++ b/custom_components/localtuya/__init__.py @@ -59,6 +59,7 @@ from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.const import ( + CONF_ID, CONF_DEVICE_ID, CONF_ENTITIES, CONF_HOST, @@ -68,8 +69,9 @@ ) from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.reload import async_integration_yaml_config +import homeassistant.helpers.entity_registry as er -from .common import TuyaDevice +from .common import TuyaDevice, async_config_entry_by_device_id from .config_flow import config_schema from .const import CONF_PRODUCT_KEY, DATA_DISCOVERY, DOMAIN, TUYA_DEVICE from .discovery import TuyaDiscovery @@ -117,14 +119,6 @@ async def _handle_reload(service): await asyncio.gather(*reload_tasks) - def _entry_by_device_id(device_id): - """Look up config entry by device id.""" - current_entries = hass.config_entries.async_entries(DOMAIN) - for entry in current_entries: - if entry.data[CONF_DEVICE_ID] == device_id: - return entry - return None - def _device_discovered(device): """Update address of device if it has changed.""" device_ip = device["ip"] @@ -133,7 +127,7 @@ def _device_discovered(device): # If device is not in cache, check if a config entry exists if device_id not in device_cache: - entry = _entry_by_device_id(device_id) + entry = async_config_entry_by_device_id(hass, device_id) if entry: # Save address from config entry in cache to trigger # potential update below @@ -142,7 +136,7 @@ def _device_discovered(device): if device_id not in device_cache: return - entry = _entry_by_device_id(device_id) + entry = async_config_entry_by_device_id(hass, device_id) if entry is None: return @@ -211,6 +205,8 @@ async def setup_entities(): ) device.connect() + await async_remove_orphan_entities(hass, entry) + hass.async_create_task(setup_entities()) return True @@ -240,3 +236,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): async def update_listener(hass, config_entry): """Update listener.""" await hass.config_entries.async_reload(config_entry.entry_id) + + +async def async_remove_orphan_entities(hass, entry): + """Remove entities associated with config entry that has been removed.""" + ent_reg = await er.async_get_registry(hass) + entities = { + int(ent.unique_id.split("_")[-1]): ent.entity_id + for ent in er.async_entries_for_config_entry(ent_reg, entry.entry_id) + } + + for entity in entry.data[CONF_ENTITIES]: + if entity[CONF_ID] in entities: + del entities[entity[CONF_ID]] + + for entity_id in entities.values(): + ent_reg.async_remove(entity_id) diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index c6653c4bd..362cae112 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -96,6 +96,16 @@ def get_entity_config(config_entry, dp_id): raise Exception(f"missing entity config for id {dp_id}") +@callback +def async_config_entry_by_device_id(hass, device_id): + """Look up config entry by device id.""" + current_entries = hass.config_entries.async_entries(DOMAIN) + for entry in current_entries: + if entry.data[CONF_DEVICE_ID] == device_id: + return entry + return None + + class TuyaDevice(pytuya.TuyaListener, pytuya.ContextualLogger): """Cache wrapper for pytuya.TuyaInterface.""" diff --git a/custom_components/localtuya/config_flow.py b/custom_components/localtuya/config_flow.py index fc0c01e29..a4984d4aa 100644 --- a/custom_components/localtuya/config_flow.py +++ b/custom_components/localtuya/config_flow.py @@ -16,7 +16,7 @@ ) from homeassistant.core import callback -from . import pytuya +from .common import pytuya, async_config_entry_by_device_id from .const import CONF_DPS_STRINGS # pylint: disable=unused-import from .const import ( CONF_LOCAL_KEY, @@ -46,14 +46,6 @@ } ) -OPTIONS_SCHEMA = vol.Schema( - { - vol.Required(CONF_FRIENDLY_NAME): str, - vol.Required(CONF_HOST): str, - vol.Required(CONF_LOCAL_KEY): str, - vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]), - } -) DEVICE_SCHEMA = vol.Schema( { @@ -70,11 +62,33 @@ ) -def user_schema(devices): +def user_schema(devices, entries): """Create schema for user step.""" - devices = [f"{ip} ({dev['gwId']})" for ip, dev in devices.items()] + devices = {dev_id: dev["ip"] for dev_id, dev in devices.items()} + devices.update( + {ent.data[CONF_DEVICE_ID]: ent.data[CONF_FRIENDLY_NAME] for ent in entries} + ) + device_list = [f"{key} ({value})" for key, value in devices.items()] + return vol.Schema( + {vol.Required(DISCOVERED_DEVICE): vol.In(device_list + [CUSTOM_DEVICE])} + ) + + +def options_schema(entities): + """Create schema for options.""" + entity_names = [ + f"{entity[CONF_ID]} {entity[CONF_FRIENDLY_NAME]}" for entity in entities + ] return vol.Schema( - {vol.Required(DISCOVERED_DEVICE): vol.In(devices + [CUSTOM_DEVICE])} + { + vol.Required(CONF_FRIENDLY_NAME): str, + vol.Required(CONF_HOST): str, + vol.Required(CONF_LOCAL_KEY): str, + vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.3"]), + vol.Required( + CONF_ENTITIES, description={"suggested_value": entity_names} + ): cv.multi_select(entity_names), + } ) @@ -213,8 +227,7 @@ async def async_step_user(self, user_input=None): errors = {} if user_input is not None: if user_input[DISCOVERED_DEVICE] != CUSTOM_DEVICE: - device = user_input[DISCOVERED_DEVICE].split(" ")[0] - self.selected_device = self.devices[device] + self.selected_device = user_input[DISCOVERED_DEVICE].split(" ")[0] return await self.async_step_basic_info() # Use cache if available or fallback to manual discovery @@ -241,7 +254,9 @@ async def async_step_user(self, user_input=None): } return self.async_show_form( - step_id="user", errors=errors, data_schema=user_schema(self.devices) + step_id="user", + errors=errors, + data_schema=user_schema(self.devices, self._async_current_entries()), ) async def async_step_basic_info(self, user_input=None): @@ -249,14 +264,13 @@ async def async_step_basic_info(self, user_input=None): errors = {} if user_input is not None: await self.async_set_unique_id(user_input[CONF_DEVICE_ID]) - self._abort_if_unique_id_configured() try: self.basic_info = user_input if self.selected_device is not None: - self.basic_info[CONF_PRODUCT_KEY] = self.selected_device[ - "productKey" - ] + self.basic_info[CONF_PRODUCT_KEY] = self.devices[ + self.selected_device + ]["productKey"] self.dps_strings = await validate_input(self.hass, user_input) return await self.async_step_pick_entity_type() except CannotConnect: @@ -269,12 +283,23 @@ async def async_step_basic_info(self, user_input=None): _LOGGER.exception("Unexpected exception") errors["base"] = "unknown" + # If selected device exists as a config entry, load config from it + if self.selected_device in self._async_current_ids(): + entry = async_config_entry_by_device_id(self.hass, self.selected_device) + await self.async_set_unique_id(entry.data[CONF_DEVICE_ID]) + self.basic_info = entry.data.copy() + self.dps_strings = self.basic_info.pop(CONF_DPS_STRINGS).copy() + self.entities = self.basic_info.pop(CONF_ENTITIES).copy() + return await self.async_step_pick_entity_type() + + # Insert default values from discovery if present defaults = {} defaults.update(user_input or {}) if self.selected_device is not None: - defaults[CONF_HOST] = self.selected_device.get("ip") - defaults[CONF_DEVICE_ID] = self.selected_device.get("gwId") - defaults[CONF_PROTOCOL_VERSION] = self.selected_device.get("version") + device = self.devices[self.selected_device] + defaults[CONF_HOST] = device.get("ip") + defaults[CONF_DEVICE_ID] = device.get("gwId") + defaults[CONF_PROTOCOL_VERSION] = device.get("version") return self.async_show_form( step_id="basic_info", @@ -291,6 +316,10 @@ async def async_step_pick_entity_type(self, user_input=None): CONF_DPS_STRINGS: self.dps_strings, CONF_ENTITIES: self.entities, } + entry = async_config_entry_by_device_id(self.hass, self.unique_id) + if entry: + self.hass.config_entries.async_update_entry(entry, data=config) + return self.async_abort(reason="device_updated") return self.async_create_entry( title=config[CONF_FRIENDLY_NAME], data=config ) @@ -313,7 +342,8 @@ async def async_step_add_entity(self, user_input=None): errors = {} if user_input is not None: already_configured = any( - switch[CONF_ID] == user_input[CONF_ID] for switch in self.entities + switch[CONF_ID] == int(user_input[CONF_ID].split(" ")[0]) + for switch in self.entities ) if not already_configured: user_input[CONF_PLATFORM] = self.platform @@ -352,13 +382,24 @@ async def async_step_init(self, user_input=None): """Manage basic options.""" device_id = self.config_entry.data[CONF_DEVICE_ID] if user_input is not None: - self.data = { - CONF_DEVICE_ID: device_id, - CONF_DPS_STRINGS: self.dps_strings, - CONF_ENTITIES: [], - } - self.data.update(user_input) - return await self.async_step_entity() + self.data = user_input.copy() + self.data.update( + { + CONF_DEVICE_ID: device_id, + CONF_DPS_STRINGS: self.dps_strings, + CONF_ENTITIES: [], + } + ) + if len(user_input[CONF_ENTITIES]) > 0: + entity_ids = [ + int(entity.split(" ")[0]) for entity in user_input[CONF_ENTITIES] + ] + self.entities = [ + entity + for entity in self.config_entry.data[CONF_ENTITIES] + if entity[CONF_ID] in entity_ids + ] + return await self.async_step_entity() # Not supported for YAML imports if self.config_entry.source == config_entries.SOURCE_IMPORT: @@ -366,7 +407,9 @@ async def async_step_init(self, user_input=None): return self.async_show_form( step_id="init", - data_schema=schema_defaults(OPTIONS_SCHEMA, **self.config_entry.data), + data_schema=schema_defaults( + options_schema(self.entities), **self.config_entry.data + ), description_placeholders={"device_id": device_id}, ) diff --git a/custom_components/localtuya/discovery.py b/custom_components/localtuya/discovery.py index 93a41bdb9..311c10530 100644 --- a/custom_components/localtuya/discovery.py +++ b/custom_components/localtuya/discovery.py @@ -49,7 +49,7 @@ async def start(self): lambda: self, local_addr=("0.0.0.0", 6667) ) - self.listeners = await asyncio.gather(listener, encrypted_listener) + self._listeners = await asyncio.gather(listener, encrypted_listener) _LOGGER.debug("Listening to broadcasts on UDP port 6666 and 6667") def close(self): @@ -72,7 +72,7 @@ def datagram_received(self, data, addr): def device_found(self, device): """Discover a new device.""" if device.get("ip") not in self.devices: - self.devices[device.get("ip")] = device + self.devices[device.get("gwId")] = device _LOGGER.debug("Discovered device: %s", device) if self._callback: diff --git a/custom_components/localtuya/translations/en.json b/custom_components/localtuya/translations/en.json index a7f361b07..b552bc03d 100644 --- a/custom_components/localtuya/translations/en.json +++ b/custom_components/localtuya/translations/en.json @@ -1,7 +1,8 @@ { "config": { "abort": { - "already_configured": "Device has already been configured." + "already_configured": "Device has already been configured.", + "device_updated": "Device configuration has been updated!" }, "error": { "cannot_connect": "Cannot connect to device. Verify that address is correct and try again.", @@ -87,7 +88,8 @@ "friendly_name": "Friendly Name", "host": "Host", "local_key": "Local key", - "protocol_version": "Protocol Version" + "protocol_version": "Protocol Version", + "entities": "Entities (uncheck an entity to remove it)" } }, "entity": {