Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Trace to avoid self import #125822

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 8 additions & 124 deletions homeassistant/components/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from __future__ import annotations

from collections.abc import Mapping
import logging
from typing import Any

import voluptuous as vol

Expand All @@ -15,17 +13,16 @@
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.limited_size_dict import LimitedSizeDict

from . import websocket_api
from .const import (
CONF_STORED_TRACES,
DATA_TRACE,
DATA_TRACE_STORE,
DATA_TRACES_RESTORED,
DEFAULT_STORED_TRACES,
)
from .models import ActionTrace, BaseTrace, RestoredTrace
from .models import ActionTrace
from .util import async_store_trace

_LOGGER = logging.getLogger(__name__)

Expand All @@ -40,7 +37,12 @@

CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)

type TraceData = dict[str, LimitedSizeDict[str, BaseTrace]]
__all__ = [
"CONF_STORED_TRACES",
"TRACE_CONFIG_SCHEMA",
"ActionTrace",
"async_store_trace",
]


async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
Expand Down Expand Up @@ -69,121 +71,3 @@ async def _async_store_traces_at_stop(_: Event) -> None:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_store_traces_at_stop)

return True


async def async_get_trace(
hass: HomeAssistant, key: str, run_id: str
) -> dict[str, BaseTrace]:
"""Return the requested trace."""
# Restore saved traces if not done
await async_restore_traces(hass)

return hass.data[DATA_TRACE][key][run_id].as_extended_dict()


async def async_list_contexts(
hass: HomeAssistant, key: str | None
) -> dict[str, dict[str, str]]:
"""List contexts for which we have traces."""
# Restore saved traces if not done
await async_restore_traces(hass)

values: Mapping[str, LimitedSizeDict[str, BaseTrace] | None] | TraceData
if key is not None:
values = {key: hass.data[DATA_TRACE].get(key)}
else:
values = hass.data[DATA_TRACE]

def _trace_id(run_id: str, key: str) -> dict[str, str]:
"""Make trace_id for the response."""
domain, item_id = key.split(".", 1)
return {"run_id": run_id, "domain": domain, "item_id": item_id}

return {
trace.context.id: _trace_id(trace.run_id, key)
for key, traces in values.items()
if traces is not None
for trace in traces.values()
}


def _get_debug_traces(hass: HomeAssistant, key: str) -> list[dict[str, Any]]:
"""Return a serializable list of debug traces for a script or automation."""
if traces_for_key := hass.data[DATA_TRACE].get(key):
return [trace.as_short_dict() for trace in traces_for_key.values()]
return []


async def async_list_traces(
hass: HomeAssistant, wanted_domain: str, wanted_key: str | None
) -> list[dict[str, Any]]:
"""List traces for a domain."""
# Restore saved traces if not done already
await async_restore_traces(hass)

if not wanted_key:
traces: list[dict[str, Any]] = []
for key in hass.data[DATA_TRACE]:
domain = key.split(".", 1)[0]
if domain == wanted_domain:
traces.extend(_get_debug_traces(hass, key))
else:
traces = _get_debug_traces(hass, wanted_key)

return traces


def async_store_trace(
hass: HomeAssistant, trace: ActionTrace, stored_traces: int
) -> None:
"""Store a trace if its key is valid."""
if key := trace.key:
traces = hass.data[DATA_TRACE]
if key not in traces:
traces[key] = LimitedSizeDict(size_limit=stored_traces)
else:
traces[key].size_limit = stored_traces
traces[key][trace.run_id] = trace


def _async_store_restored_trace(hass: HomeAssistant, trace: RestoredTrace) -> None:
"""Store a restored trace and move it to the end of the LimitedSizeDict."""
key = trace.key
traces = hass.data[DATA_TRACE]
if key not in traces:
traces[key] = LimitedSizeDict()
traces[key][trace.run_id] = trace
traces[key].move_to_end(trace.run_id, last=False)


async def async_restore_traces(hass: HomeAssistant) -> None:
"""Restore saved traces."""
if DATA_TRACES_RESTORED in hass.data:
return

hass.data[DATA_TRACES_RESTORED] = True

store = hass.data[DATA_TRACE_STORE]
try:
restored_traces = await store.async_load() or {}
except HomeAssistantError:
_LOGGER.exception("Error loading traces")
restored_traces = {}

for key, traces in restored_traces.items():
# Add stored traces in reversed order to prioritize the newest traces
for json_trace in reversed(traces):
if (
(stored_traces := hass.data[DATA_TRACE].get(key))
and stored_traces.size_limit is not None
and len(stored_traces) >= stored_traces.size_limit
):
break

try:
trace = RestoredTrace(json_trace)
# Catch any exception to not blow up if the stored trace is invalid
except Exception:
_LOGGER.exception("Failed to restore trace")
continue
_async_store_restored_trace(hass, trace)
2 changes: 1 addition & 1 deletion homeassistant/components/trace/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from homeassistant.helpers.storage import Store

from . import TraceData
from .models import TraceData


CONF_STORED_TRACES = "stored_traces"
Expand Down
3 changes: 3 additions & 0 deletions homeassistant/components/trace/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
trace_set_child_id,
)
import homeassistant.util.dt as dt_util
from homeassistant.util.limited_size_dict import LimitedSizeDict
import homeassistant.util.uuid as uuid_util

type TraceData = dict[str, LimitedSizeDict[str, BaseTrace]]


class BaseTrace(abc.ABC):
"""Base container for a script or automation trace."""
Expand Down
134 changes: 134 additions & 0 deletions homeassistant/components/trace/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Support for script and automation tracing and debugging."""

from __future__ import annotations

from collections.abc import Mapping
import logging
from typing import Any

from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.limited_size_dict import LimitedSizeDict

from .const import DATA_TRACE, DATA_TRACE_STORE, DATA_TRACES_RESTORED
from .models import ActionTrace, BaseTrace, RestoredTrace, TraceData

_LOGGER = logging.getLogger(__name__)


async def async_get_trace(
hass: HomeAssistant, key: str, run_id: str
) -> dict[str, BaseTrace]:
"""Return the requested trace."""
# Restore saved traces if not done
await async_restore_traces(hass)

return hass.data[DATA_TRACE][key][run_id].as_extended_dict()


async def async_list_contexts(
hass: HomeAssistant, key: str | None
) -> dict[str, dict[str, str]]:
"""List contexts for which we have traces."""
# Restore saved traces if not done
await async_restore_traces(hass)

values: Mapping[str, LimitedSizeDict[str, BaseTrace] | None] | TraceData
if key is not None:
values = {key: hass.data[DATA_TRACE].get(key)}
else:
values = hass.data[DATA_TRACE]

def _trace_id(run_id: str, key: str) -> dict[str, str]:
"""Make trace_id for the response."""
domain, item_id = key.split(".", 1)
return {"run_id": run_id, "domain": domain, "item_id": item_id}

return {
trace.context.id: _trace_id(trace.run_id, key)
for key, traces in values.items()
if traces is not None
for trace in traces.values()
}


def _get_debug_traces(hass: HomeAssistant, key: str) -> list[dict[str, Any]]:
"""Return a serializable list of debug traces for a script or automation."""
if traces_for_key := hass.data[DATA_TRACE].get(key):
return [trace.as_short_dict() for trace in traces_for_key.values()]
return []


async def async_list_traces(
hass: HomeAssistant, wanted_domain: str, wanted_key: str | None
) -> list[dict[str, Any]]:
"""List traces for a domain."""
# Restore saved traces if not done already
await async_restore_traces(hass)

if not wanted_key:
traces: list[dict[str, Any]] = []
for key in hass.data[DATA_TRACE]:
domain = key.split(".", 1)[0]
if domain == wanted_domain:
traces.extend(_get_debug_traces(hass, key))
else:
traces = _get_debug_traces(hass, wanted_key)

return traces


def async_store_trace(
hass: HomeAssistant, trace: ActionTrace, stored_traces: int
) -> None:
"""Store a trace if its key is valid."""
if key := trace.key:
traces = hass.data[DATA_TRACE]
if key not in traces:
traces[key] = LimitedSizeDict(size_limit=stored_traces)
else:
traces[key].size_limit = stored_traces
traces[key][trace.run_id] = trace


def _async_store_restored_trace(hass: HomeAssistant, trace: RestoredTrace) -> None:
"""Store a restored trace and move it to the end of the LimitedSizeDict."""
key = trace.key
traces = hass.data[DATA_TRACE]
if key not in traces:
traces[key] = LimitedSizeDict()
traces[key][trace.run_id] = trace
traces[key].move_to_end(trace.run_id, last=False)


async def async_restore_traces(hass: HomeAssistant) -> None:
"""Restore saved traces."""
if DATA_TRACES_RESTORED in hass.data:
return

hass.data[DATA_TRACES_RESTORED] = True

store = hass.data[DATA_TRACE_STORE]
try:
restored_traces = await store.async_load() or {}
except HomeAssistantError:
_LOGGER.exception("Error loading traces")
restored_traces = {}

Check warning on line 116 in homeassistant/components/trace/util.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/trace/util.py#L114-L116

Added lines #L114 - L116 were not covered by tests

for key, traces in restored_traces.items():
# Add stored traces in reversed order to prioritize the newest traces
for json_trace in reversed(traces):
if (
(stored_traces := hass.data[DATA_TRACE].get(key))
and stored_traces.size_limit is not None
and len(stored_traces) >= stored_traces.size_limit
):
break

try:
trace = RestoredTrace(json_trace)
# Catch any exception to not blow up if the stored trace is invalid
except Exception:
_LOGGER.exception("Failed to restore trace")
continue

Check warning on line 133 in homeassistant/components/trace/util.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/trace/util.py#L131-L133

Added lines #L131 - L133 were not covered by tests
_async_store_restored_trace(hass, trace)
8 changes: 4 additions & 4 deletions homeassistant/components/trace/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
debug_stop,
)

from .. import trace # noqa: TID252 (see PR 125822)
from .util import async_get_trace, async_list_contexts, async_list_traces

TRACE_DOMAINS = ("automation", "script")

Expand Down Expand Up @@ -66,7 +66,7 @@ async def websocket_trace_get(
run_id = msg["run_id"]

try:
requested_trace = await trace.async_get_trace(hass, key, run_id)
requested_trace = await async_get_trace(hass, key, run_id)
except KeyError:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "The trace could not be found"
Expand Down Expand Up @@ -98,7 +98,7 @@ async def websocket_trace_list(
wanted_domain = msg["domain"]
key = f"{msg['domain']}.{msg['item_id']}" if "item_id" in msg else None

traces = await trace.async_list_traces(hass, wanted_domain, key)
traces = await async_list_traces(hass, wanted_domain, key)

connection.send_result(msg["id"], traces)

Expand All @@ -120,7 +120,7 @@ async def websocket_trace_contexts(
"""Retrieve contexts we have traces for."""
key = f"{msg['domain']}.{msg['item_id']}" if "item_id" in msg else None

contexts = await trace.async_list_contexts(hass, key)
contexts = await async_list_contexts(hass, key)

connection.send_result(msg["id"], contexts)

Expand Down