From d812507aebbd0156a07fb700ac71fa69cf61bec9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 26 Feb 2024 21:08:28 -1000 Subject: [PATCH] Refactor eafm to avoid creating entities in the coordinator update (#111601) --- homeassistant/components/eafm/__init__.py | 44 +++++++++++++- homeassistant/components/eafm/sensor.py | 71 +++++++---------------- tests/components/eafm/conftest.py | 2 +- tests/components/eafm/test_sensor.py | 13 +++++ 4 files changed, 79 insertions(+), 51 deletions(-) diff --git a/homeassistant/components/eafm/__init__.py b/homeassistant/components/eafm/__init__.py index a57558ff1cc7a..d18553b6ee61e 100644 --- a/homeassistant/components/eafm/__init__.py +++ b/homeassistant/components/eafm/__init__.py @@ -1,16 +1,58 @@ """UK Environment Agency Flood Monitoring Integration.""" +import asyncio +from datetime import timedelta +import logging +from typing import Any + +from aioeafm import get_station + from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant +from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from .const import DOMAIN PLATFORMS = [Platform.SENSOR] +_LOGGER = logging.getLogger(__name__) + + +def get_measures(station_data): + """Force measure key to always be a list.""" + if "measures" not in station_data: + return [] + if isinstance(station_data["measures"], dict): + return [station_data["measures"]] + return station_data["measures"] + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up flood monitoring sensors for this config entry.""" - hass.data.setdefault(DOMAIN, {}) + station_key = entry.data["station"] + session = async_get_clientsession(hass=hass) + + async def _async_update_data() -> dict[str, dict[str, Any]]: + # DataUpdateCoordinator will handle aiohttp ClientErrors and timeouts + async with asyncio.timeout(30): + data = await get_station(session, station_key) + + measures = get_measures(data) + # Turn data.measures into a dict rather than a list so easier for entities to + # find themselves. + data["measures"] = {measure["@id"]: measure for measure in measures} + return data + + coordinator = DataUpdateCoordinator[dict[str, dict[str, Any]]]( + hass, + _LOGGER, + name="sensor", + update_method=_async_update_data, + update_interval=timedelta(seconds=15 * 60), + ) + await coordinator.async_config_entry_first_refresh() + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True diff --git a/homeassistant/components/eafm/sensor.py b/homeassistant/components/eafm/sensor.py index 2c7f8456a72b3..297f4d6d2c8e5 100644 --- a/homeassistant/components/eafm/sensor.py +++ b/homeassistant/components/eafm/sensor.py @@ -1,15 +1,11 @@ """Support for gauges from flood monitoring API.""" -import asyncio -from datetime import timedelta -import logging -from aioeafm import get_station +from typing import Any from homeassistant.components.sensor import SensorEntity, SensorStateClass from homeassistant.config_entries import ConfigEntry from homeassistant.const import UnitOfLength -from homeassistant.core import HomeAssistant -from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import ( @@ -19,72 +15,49 @@ from .const import DOMAIN -_LOGGER = logging.getLogger(__name__) - UNIT_MAPPING = { "http://qudt.org/1.1/vocab/unit#Meter": UnitOfLength.METERS, } -def get_measures(station_data): - """Force measure key to always be a list.""" - if "measures" not in station_data: - return [] - if isinstance(station_data["measures"], dict): - return [station_data["measures"]] - return station_data["measures"] - - async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up UK Flood Monitoring Sensors.""" - station_key = config_entry.data["station"] - session = async_get_clientsession(hass=hass) - - measurements = set() - - async def async_update_data(): - # DataUpdateCoordinator will handle aiohttp ClientErrors and timeouts - async with asyncio.timeout(30): - data = await get_station(session, station_key) - - measures = get_measures(data) - entities = [] - + coordinator: DataUpdateCoordinator = hass.data[DOMAIN][config_entry.entry_id] + created_entities: set[str] = set() + + @callback + def _async_create_new_entities(): + """Create new entities.""" + if not coordinator.last_update_success: + return + measures: dict[str, dict[str, Any]] = coordinator.data["measures"] + entities: list[Measurement] = [] # Look to see if payload contains new measures - for measure in measures: - if measure["@id"] in measurements: + for key, data in measures.items(): + if key in created_entities: continue - if "latestReading" not in measure: + if "latestReading" not in data: # Don't create a sensor entity for a gauge that isn't available continue - entities.append(Measurement(hass.data[DOMAIN][station_key], measure["@id"])) - measurements.add(measure["@id"]) + entities.append(Measurement(coordinator, key)) + created_entities.add(key) async_add_entities(entities) - # Turn data.measures into a dict rather than a list so easier for entities to - # find themselves. - data["measures"] = {measure["@id"]: measure for measure in measures} + _async_create_new_entities() - return data - - hass.data[DOMAIN][station_key] = coordinator = DataUpdateCoordinator( - hass, - _LOGGER, - name="sensor", - update_method=async_update_data, - update_interval=timedelta(seconds=15 * 60), + # Subscribe to the coordinator to create new entities + # when the coordinator updates + config_entry.async_on_unload( + coordinator.async_add_listener(_async_create_new_entities) ) - # Fetch initial data so we have data when entities subscribe - await coordinator.async_refresh() - class Measurement(CoordinatorEntity, SensorEntity): """A gauge at a flood monitoring station.""" diff --git a/tests/components/eafm/conftest.py b/tests/components/eafm/conftest.py index b5deba4b9243a..3b060563a30f9 100644 --- a/tests/components/eafm/conftest.py +++ b/tests/components/eafm/conftest.py @@ -15,5 +15,5 @@ def mock_get_stations(): @pytest.fixture def mock_get_station(): """Mock aioeafm.get_station.""" - with patch("homeassistant.components.eafm.sensor.get_station") as patched: + with patch("homeassistant.components.eafm.get_station") as patched: yield patched diff --git a/tests/components/eafm/test_sensor.py b/tests/components/eafm/test_sensor.py index f02b800a18f65..04659a79d1c0d 100644 --- a/tests/components/eafm/test_sensor.py +++ b/tests/components/eafm/test_sensor.py @@ -339,6 +339,19 @@ async def test_ignore_no_latest_reading(hass: HomeAssistant, mock_get_station) - assert state is None +async def test_no_measures(hass: HomeAssistant, mock_get_station) -> None: + """Test no measures in the data.""" + await async_setup_test_fixture( + hass, + mock_get_station, + { + "label": "My station", + }, + ) + + assert hass.states.async_entity_ids_count() == 0 + + async def test_mark_existing_as_unavailable_if_no_latest( hass: HomeAssistant, mock_get_station ) -> None: