Skip to content

Commit

Permalink
Allow nested schema validation in event automation trigger
Browse files Browse the repository at this point in the history
  • Loading branch information
Hypfer committed Sep 25, 2024
1 parent d7ac53a commit a563e93
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
24 changes: 16 additions & 8 deletions homeassistant/components/homeassistant/triggers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,23 @@ async def async_attach_trigger(
event_data.update(
template.render_complex(config[CONF_EVENT_DATA], variables, limited=True)
)
# Build the schema or a an items view if the schema is simple
# and does not contain sub-dicts. We explicitly do not check for
# list like the context data below since lists are a special case
# only for context data. (see test test_event_data_with_list)

def build_schema(data):

Check failure on line 79 in homeassistant/components/homeassistant/triggers/event.py

View workflow job for this annotation

GitHub Actions / Check mypy

Function is missing a type annotation [no-untyped-def]
schema = {}
for key, value in data.items():
if isinstance(value, dict):
schema[vol.Required(key)] = build_schema(value)

Check failure on line 83 in homeassistant/components/homeassistant/triggers/event.py

View workflow job for this annotation

GitHub Actions / Check mypy

Call to untyped function "build_schema" in typed context [no-untyped-call]
else:
schema[vol.Required(key)] = value
return vol.Schema(schema, extra=vol.ALLOW_EXTRA)

# For performance reasons, we want to avoid using a voluptuous schema here unless required
# Thus, if possible, we try to use a simple items comparison
# For that, we explicitly do not check for list like the context data below since lists are
# a special case only used for context data. (see test test_event_data_with_list)
# Otherwise, we recursively build the schema (see test test_event_data_with_list_nested)
if any(isinstance(value, dict) for value in event_data.values()):
event_data_schema = vol.Schema(
{vol.Required(key): value for key, value in event_data.items()},
extra=vol.ALLOW_EXTRA,
)
event_data_schema = build_schema(event_data)

Check failure on line 94 in homeassistant/components/homeassistant/triggers/event.py

View workflow job for this annotation

GitHub Actions / Check mypy

Call to untyped function "build_schema" in typed context [no-untyped-call]
else:
# Use a simple items comparison if possible
event_data_items = event_data.items()
Expand Down
45 changes: 45 additions & 0 deletions tests/components/homeassistant/triggers/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,51 @@ async def test_event_data_with_list(
await hass.async_block_till_done()
assert len(service_calls) == 1

# don't match if property doesn't exist at all
hass.bus.async_fire("test_event", {"other_attr": [1, 2]})
await hass.async_block_till_done()
assert len(service_calls) == 1


async def test_event_data_with_list_nested(
hass: HomeAssistant, service_calls: list[ServiceCall]
) -> None:
"""Test the (non)firing of event when the data schema has nested lists."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {
"platform": "event",
"event_type": "test_event",
"event_data": {"service_data": {"some_attr": [1, 2]}},
"context": {},
},
"action": {"service": "test.automation"},
}
},
)

hass.bus.async_fire("test_event", {"service_data": {"some_attr": [1, 2]}})
await hass.async_block_till_done()
assert len(service_calls) == 1

# don't match a single value
hass.bus.async_fire("test_event", {"service_data": {"some_attr": 1}})
await hass.async_block_till_done()
assert len(service_calls) == 1

# don't match a containing list
hass.bus.async_fire("test_event", {"service_data": {"some_attr": [1, 2, 3]}})
await hass.async_block_till_done()
assert len(service_calls) == 1

# don't match if property doesn't exist at all
hass.bus.async_fire("test_event", {"service_data": {"other_attr": [1, 2]}})
await hass.async_block_till_done()
assert len(service_calls) == 1


@pytest.mark.parametrize(
"event_type", ["state_reported", ["test_event", "state_reported"]]
Expand Down

0 comments on commit a563e93

Please sign in to comment.