Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert ZCL attributes from strings when writing #267

Merged
merged 4 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 84 additions & 2 deletions tests/test_application_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Test zha application helpers."""

from typing import Any

import pytest
from zigpy.device import Device as ZigpyDevice
from zigpy.profiles import zha
from zigpy.zcl.clusters.general import Basic, OnOff
import zigpy.types as t
from zigpy.zcl.clusters.general import Basic, Identify, OnOff
from zigpy.zcl.clusters.security import IasZone

from tests.common import (
Expand All @@ -14,7 +18,12 @@
join_zigpy_device,
)
from zha.application.gateway import Gateway
from zha.application.helpers import async_is_bindable_target, get_matched_clusters
from zha.application.helpers import (
async_is_bindable_target,
convert_to_zcl_values,
convert_zcl_value,
get_matched_clusters,
)

IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
Expand Down Expand Up @@ -105,3 +114,76 @@ async def test_get_matched_clusters(
assert matches[0].target_ep_id == 1

assert not await get_matched_clusters(not_bindable_zha_device, remote_zha_device)


class SomeEnum(t.enum8):
"""Some enum."""

value_1 = 0x12
value_2 = 0x34
value_3 = 0x56


class SomeFlag(t.bitmap8):
"""Some bitmap."""

flag_1 = 0b00000001
flag_2 = 0b00000010
flag_3 = 0b00000100


@pytest.mark.parametrize(
("text", "field_type", "result"),
[
# Bytes
(
"b'Some data\\x00\\x01'",
t.SerializableBytes,
t.SerializableBytes(b"Some data\x00\x01"),
),
(
'b"Some data\\x00\\x01"',
t.SerializableBytes,
t.SerializableBytes(b"Some data\x00\x01"),
),
(
b"Some data\x00\x01".hex(),
t.SerializableBytes,
t.SerializableBytes(b"Some data\x00\x01"),
),
# Enum
("value 1", SomeEnum, SomeEnum.value_1),
("value_1", SomeEnum, SomeEnum.value_1),
("SomeEnum.value_1", SomeEnum, SomeEnum.value_1),
(0x12, SomeEnum, SomeEnum.value_1),
# Flag
("flag 1", SomeFlag, SomeFlag.flag_1),
("flag_1", SomeFlag, SomeFlag.flag_1),
("SomeFlag.flag_1", SomeFlag, SomeFlag.flag_1),
("SomeFlag.flag_1|flag_2", SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
(0b00000001, SomeFlag, SomeFlag.flag_1),
([0b00000001], SomeFlag, SomeFlag.flag_1),
([0b00000001, 0b00000010], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
(["flag_1", "flag_2"], SomeFlag, SomeFlag.flag_1 | SomeFlag.flag_2),
# Int
(0x1234, t.uint16_t, 0x1234),
("0x1234", t.uint16_t, 0x1234),
("4660", t.uint16_t, 0x1234),
# Some fallthrough type
(1.000, t.Single, t.Single(1.000)),
("1.000", t.Single, t.Single(1.000)),
],
)
def test_convert_zcl_value(text: Any, field_type: Any, result: Any) -> None:
"""Test converting ZCL values."""
assert convert_zcl_value(text, field_type) == result


def test_convert_to_zcl_values() -> None:
"""Test converting ZCL values."""

identify_schema = Identify.ServerCommandDefs.identify.schema
assert convert_to_zcl_values(
fields={"identify_time": "1"},
schema=identify_schema,
) == {"identify_time": 1}
72 changes: 53 additions & 19 deletions zha/application/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import ast
import asyncio
import binascii
import collections
from collections.abc import Callable
import contextlib
import dataclasses
from dataclasses import dataclass
import datetime
Expand Down Expand Up @@ -126,6 +128,53 @@ async def get_matched_clusters(
return clusters_to_bind


def convert_zcl_value(value: Any, field_type: Any) -> Any:
"""Convert user input to ZCL value."""
if issubclass(field_type, enum.Flag):
if isinstance(value, str):
with contextlib.suppress(ValueError):
value = int(value)

if isinstance(value, int):
value = field_type(value)
elif isinstance(value, str):
# List of flags: `SomeFlag.field1 | field2`
value = [v.strip() for v in value.split(".", 1)[-1].split("|")]

if isinstance(value, list):
new_value = 0

for flag in value:
if isinstance(flag, str):
new_value |= field_type[flag.replace(" ", "_")]
else:
new_value |= flag

value = field_type(new_value)
elif issubclass(field_type, enum.Enum):
value = (
field_type[value.replace(" ", "_").split(".", 1)[-1]]
if isinstance(value, str)
else field_type(value)
)
elif issubclass(field_type, zigpy.types.SerializableBytes):
if value.startswith(("b'", 'b"')):
value = ast.literal_eval(value)
else:
value = bytes.fromhex(value)

value = field_type(value)
elif issubclass(field_type, int):
if isinstance(value, str) and value.startswith("0x"):
value = int(value, 16)

value = field_type(value)
else:
value = field_type(value)

return value


def convert_to_zcl_values(
fields: dict[str, Any], schema: CommandSchema
) -> dict[str, Any]:
Expand All @@ -134,32 +183,17 @@ def convert_to_zcl_values(
for field in schema.fields:
if field.name not in fields:
continue
value = fields[field.name]
if issubclass(field.type, enum.Flag) and isinstance(value, list):
new_value = 0

for flag in value:
if isinstance(flag, str):
new_value |= field.type[flag.replace(" ", "_")]
else:
new_value |= flag
value = fields[field.name]
new_value = converted_fields[field.name] = convert_zcl_value(value, field.type)

value = field.type(new_value)
elif issubclass(field.type, enum.Enum):
value = (
field.type[value.replace(" ", "_")]
if isinstance(value, str)
else field.type(value)
)
else:
value = field.type(value)
_LOGGER.debug(
"Converted ZCL schema field(%s) value from: %s to: %s",
field.name,
fields[field.name],
value,
new_value,
)
converted_fields[field.name] = value

return converted_fields


Expand Down
5 changes: 4 additions & 1 deletion zha/zigbee/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
ZHA_CLUSTER_HANDLER_MSG,
ZHA_EVENT,
)
from zha.application.helpers import convert_to_zcl_values
from zha.application.helpers import convert_to_zcl_values, convert_zcl_value
from zha.application.platforms import BaseEntityInfo, PlatformEntity
from zha.event import EventBase
from zha.exceptions import ZHAException
Expand Down Expand Up @@ -865,6 +865,9 @@ async def write_zigbee_attribute(
f" writing attribute {attribute} with value {value}"
) from exc

attr_def = cluster.find_attribute(attribute)
value = convert_zcl_value(value, attr_def.type)

try:
response = await cluster.write_attributes(
{attribute: value}, manufacturer=manufacturer
Expand Down