Skip to content

Commit

Permalink
Update device locations when integration loads
Browse files Browse the repository at this point in the history
- Location syncing is disabled by default
- HA areas area created as necessary
  • Loading branch information
jason0x43 committed Jan 3, 2025
1 parent a4bb90a commit 8abac40
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 265 deletions.
75 changes: 42 additions & 33 deletions custom_components/hubitat/config_flow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Config flow for Hubitat integration."""

import logging
from collections.abc import Awaitable
from copy import deepcopy
from typing import Any, Awaitable, Callable, cast
from typing import Any, Callable, TypedDict, override

import voluptuous as vol
from voluptuous.schema_builder import Schema
Expand All @@ -12,6 +13,7 @@
CONN_CLASS_LOCAL_PUSH,
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
OptionsFlowWithConfigEntry,
)
Expand All @@ -20,14 +22,6 @@
from homeassistant.helpers import device_registry
from homeassistant.helpers.device_registry import DeviceEntry

try:
from homeassistant.config_entries import ConfigFlowResult # type: ignore
except Exception:
from homeassistant.config_entries import (
FlowResult as ConfigFlowResult, # type: ignore
)


from .const import (
DOMAIN,
H_CONF_APP_ID,
Expand All @@ -38,6 +32,7 @@
H_CONF_SERVER_SSL_CERT,
H_CONF_SERVER_SSL_KEY,
H_CONF_SERVER_URL,
H_CONF_SYNC_AREAS,
TEMP_C,
TEMP_F,
ConfigStep,
Expand Down Expand Up @@ -66,29 +61,30 @@
vol.Optional(H_CONF_SERVER_PORT): int,
vol.Optional(H_CONF_SERVER_SSL_CERT): str,
vol.Optional(H_CONF_SERVER_SSL_KEY): str,
vol.Optional(
CONF_TEMPERATURE_UNIT, default=cast(vol.Undefined, TEMP_F)
): vol.In([TEMP_F, TEMP_C]),
vol.Optional(CONF_TEMPERATURE_UNIT, default=TEMP_F): vol.In([TEMP_F, TEMP_C]),
vol.Optional(H_CONF_SYNC_AREAS, default=False): bool,
}
)


class HubitatConfigFlow(ConfigFlow, domain=DOMAIN): # type: ignore
"""Handle a config flow for Hubitat."""

VERSION = 1
CONNECTION_CLASS = CONN_CLASS_LOCAL_PUSH
VERSION: int = 1
CONNECTION_CLASS: str = CONN_CLASS_LOCAL_PUSH

hub: HubitatHub | None = None
device_schema: Schema
device_schema: Schema | None = None

@staticmethod
@callback
@override
def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
return HubitatOptionsFlow(config_entry)

# TODO: remove the 'type: ignore' when were not falling back on
# FlowResult
@override
async def async_step_user( # type: ignore
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
Expand All @@ -101,7 +97,7 @@ async def async_step_user( # type: ignore
entry_data = deepcopy(user_input)
self.hub = info["hub"]

placeholders: dict[str, str] = {}
placeholders: dict[str, Any] = {}
for key in user_input:
if user_input[key] is not None and key in placeholders:
placeholders[key] = user_input[key]
Expand Down Expand Up @@ -149,14 +145,14 @@ class HubitatOptionsFlow(OptionsFlowWithConfigEntry):

hub: HubitatHub | None = None
overrides: dict[str, str] = {}
should_remove_devices = False
should_remove_devices: bool = False

def __init__(self, config_entry: ConfigEntry):
"""Initialize an options flow."""
super().__init__(config_entry)

async def async_step_init(
self, user_input: dict[str, str] | None = None
self, _user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle integration options."""
return await self.async_step_user()
Expand All @@ -180,6 +176,7 @@ async def async_step_user(
H_CONF_SERVER_URL: user_input.get(H_CONF_SERVER_URL),
H_CONF_SERVER_SSL_CERT: user_input.get(H_CONF_SERVER_SSL_CERT),
H_CONF_SERVER_SSL_KEY: user_input.get(H_CONF_SERVER_SSL_KEY),
H_CONF_SYNC_AREAS: user_input.get(H_CONF_SYNC_AREAS),
}

info = await _validate_input(check_input)
Expand All @@ -195,6 +192,7 @@ async def async_step_user(
H_CONF_SERVER_SSL_KEY
)
self.options[CONF_TEMPERATURE_UNIT] = user_input[CONF_TEMPERATURE_UNIT]
self.options[H_CONF_SYNC_AREAS] = user_input.get(H_CONF_SYNC_AREAS)

_LOGGER.debug("Moving to device removal step")
return await self.async_step_remove_devices()
Expand Down Expand Up @@ -273,15 +271,20 @@ async def async_step_user(
): str,
vol.Optional(
CONF_TEMPERATURE_UNIT,
default=cast(
vol.Undefined,
entry.options.get(
CONF_TEMPERATURE_UNIT,
entry.data.get(CONF_TEMPERATURE_UNIT),
),
default=entry.options.get(
CONF_TEMPERATURE_UNIT,
entry.data.get(CONF_TEMPERATURE_UNIT),
)
or TEMP_F,
): vol.In([TEMP_F, TEMP_C]),
vol.Optional(
H_CONF_SYNC_AREAS,
default=entry.options.get(
H_CONF_SYNC_AREAS,
entry.data.get(H_CONF_SYNC_AREAS),
)
or False,
): bool,
}
),
errors=form_errors,
Expand Down Expand Up @@ -309,14 +312,13 @@ async def async_step_remove_devices(

device_schema = vol.Schema(
{
vol.Optional(
H_CONF_DEVICES, default=cast(vol.Undefined, [])
): cv.multi_select(device_map),
vol.Optional(H_CONF_DEVICES, default=[]): cv.multi_select(device_map),
}
)

if user_input is not None:
ids = [id for id in user_input[H_CONF_DEVICES]]
conf_devs: list[str] = user_input[H_CONF_DEVICES]
ids = [id for id in conf_devs]
_remove_devices(self.hass, ids)
return await self.async_step_override_lights()

Expand Down Expand Up @@ -385,7 +387,9 @@ async def _async_step_override_type(
# update will be triggered if devices are added or removed
self.options[H_CONF_DEVICE_LIST] = sorted([id for id in devices])

existing_overrides = self.options.get(H_CONF_DEVICE_TYPE_OVERRIDES)
existing_overrides: dict[str, str] | None = self.options.get(
H_CONF_DEVICE_TYPE_OVERRIDES
)
default_value = []

possible_overrides = {
Expand All @@ -401,9 +405,9 @@ async def _async_step_override_type(

device_schema = vol.Schema(
{
vol.Optional(
H_CONF_DEVICES, default=cast(vol.Undefined, default_value)
): cv.multi_select(possible_overrides)
vol.Optional(H_CONF_DEVICES, default=default_value): cv.multi_select(
possible_overrides
)
}
)

Expand Down Expand Up @@ -454,7 +458,12 @@ def _remove_devices(hass: HomeAssistant, device_ids: list[str]) -> None:
dreg.async_remove_device(id)


async def _validate_input(user_input: dict[str, Any]) -> dict[str, Any]:
class ValidatedInput(TypedDict):
label: str
hub: HubitatHub


async def _validate_input(user_input: dict[str, Any]) -> ValidatedInput:
"""Validate that the user input can create a working connection."""

# data has the keys from CONFIG_SCHEMA with values provided by the user.
Expand Down
2 changes: 1 addition & 1 deletion custom_components/hubitat/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class EventName(StrEnum):
H_CONF_DEVICE_TYPE_OVERRIDES = "device_type_overrides"
H_CONF_SERVER_PORT = "server_port"
H_CONF_SERVER_URL = "server_url"
H_CONF_USE_SERVER_URL = "use_server_url"
H_CONF_SERVER_SSL_CERT = "server_ssl_cert"
H_CONF_SERVER_SSL_KEY = "server_ssl_key"
H_CONF_SYNC_AREAS = "sync_areas"
H_CONF_BUTTON = "button"
H_CONF_DOUBLE_TAPPED = "double_tapped"
H_CONF_HELD = "held"
Expand Down
8 changes: 2 additions & 6 deletions custom_components/hubitat/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .hub import Hub
from .hubitatmaker import Device, DeviceAttribute, Event
from .types import Removable, UpdateableEntity
from .util import get_hub_device_id
from .util import get_device_identifiers, get_hub_device_id

_LOGGER = getLogger(__name__)

Expand Down Expand Up @@ -198,12 +198,8 @@ def __repr__(self) -> str:

def get_device_info(hub: Hub, device: Device) -> device_registry.DeviceInfo:
"""Return the device info."""
dev_identifier = device.id
if hub.id != device.id:
dev_identifier = f"{hub.id}:{device.id}"

info: device_registry.DeviceInfo = {
"identifiers": {(DOMAIN, dev_identifier)},
"identifiers": get_device_identifiers(hub.id, device.id)
}

# if this entity's device isn't the hub, link it to the hub
Expand Down
62 changes: 52 additions & 10 deletions custom_components/hubitat/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
UnitOfTemperature,
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers import device_registry
from homeassistant.helpers import area_registry, device_registry
from homeassistant.helpers.device_registry import DeviceEntry

try:
Expand All @@ -41,6 +41,7 @@ class DiscoveryKey:
H_CONF_SERVER_SSL_CERT,
H_CONF_SERVER_SSL_KEY,
H_CONF_SERVER_URL,
H_CONF_SYNC_AREAS,
PLATFORMS,
TEMP_F,
TRIGGER_CAPABILITIES,
Expand All @@ -51,7 +52,7 @@ class DiscoveryKey:
Hub as HubitatHub,
)
from .types import Removable, UpdateableEntity
from .util import get_hub_device_id, get_hub_short_id
from .util import get_device_identifiers, get_hub_device_id, get_hub_short_id

_LOGGER = getLogger(__name__)

Expand Down Expand Up @@ -351,8 +352,6 @@ async def create(hass: HomeAssistant, entry: ConfigEntry, index: int) -> "Hub":
hub = Hub(hass, entry, index, hubitat_hub, device)
hass.data[DOMAIN][entry.entry_id] = hub

_LOGGER.debug(f"Set {entry.entry_id} on hass.data")

# Add a listener for every device exported by the hub. The listener
# will re-export the Hubitat event as a hubitat_event in HA if it
# matches a trigger condition.
Expand All @@ -370,6 +369,12 @@ async def create(hass: HomeAssistant, entry: ConfigEntry, index: int) -> "Hub":

_LOGGER.debug("Registered platforms")

should_update_rooms = entry.options.get(
H_CONF_SYNC_AREAS, entry.data.get(H_CONF_SYNC_AREAS)
)
if should_update_rooms:
_update_device_rooms(hub, hass)

# Create an entity for the Hubitat hub with basic hub information
hass.states.async_set(
hub.entity_id,
Expand Down Expand Up @@ -616,7 +621,7 @@ def _update_device_ids(hub_id: str, hass: HomeAssistant) -> None:
else:
_LOGGER.warning(f"Device identifier in unknown format: {id_set}")

# Update any devices with 3-part identifiers to use 2-part identifiers in
# Update any devices with 3-part identifiers to use 2-part identifiers in
# the new format
for id in new_devs:
new_dev = new_devs[id]
Expand Down Expand Up @@ -654,11 +659,48 @@ def _update_device_ids(hub_id: str, hass: HomeAssistant) -> None:
)


# hub_typecheck: Hub
# """A hub instance that will only exist during development."""
#
# device_typecheck: Device
# """A device instance that will only exist during development."""
def _update_device_rooms(hub: Hub, hass: HomeAssistant) -> None:
"""Update the area ID of devices in Home Assistant to match the room from
Hubitat
"""

_LOGGER.debug("Synchronizing device rooms...")

dreg = device_registry.async_get(hass)
all_devices = [dreg.devices[id] for id in dreg.devices]
hubitat_devices: list[DeviceEntry] = []
for dev in all_devices:
ids = list(dev.identifiers)
if len(ids) != 1:
continue
id_set = ids[0]
if id_set[0] == DOMAIN and dev.name != HUB_NAME:
hubitat_devices.append(dev)

areg = area_registry.async_get(hass)

for device_id in hub.devices:
# if the device area from Hubitat is defined and differs from
# Home Assistant, update the device area ID
device = hub.devices[device_id]
identifiers = get_device_identifiers(hub.id, device.id)
hass_device = dreg.async_get_device(identifiers)

if not hass_device:
_LOGGER.debug(
"Skipping room check for %s because no HA device", device.name
)
continue

if device.room:
area = areg.async_get_or_create(device.room)
if hass_device.area_id != area.id:
_ = dreg.async_update_device(hass_device.id, area_id=area.id)
_LOGGER.debug("Updated location of %s to %s", device.name, area.id)
elif hass_device.area_id:
_ = dreg.async_clear_area_id(hass_device.id)
_LOGGER.debug("Cleared location of %s", device.name)


if TYPE_CHECKING:
test_hass = HomeAssistant("")
Expand Down
Loading

0 comments on commit 8abac40

Please sign in to comment.