Skip to content

Commit

Permalink
Implement device factory (#1556)
Browse files Browse the repository at this point in the history
This will make it simple for downstream users to construct device
instances for all supported devices given only the host and its token.

All device subclasses register themselves automatically to the factory.
The create(host, token, model=None) class method is the main entry point
to use this.

Supersedes #1328
Fixes #1117
  • Loading branch information
rytilahti authored Oct 23, 2022
1 parent dc55eba commit 8270db5
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 21 deletions.
1 change: 1 addition & 0 deletions miio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from miio.cloud import CloudInterface
from miio.cooker import Cooker
from miio.curtain_youpin import CurtainMiot
from miio.devicefactory import DeviceFactory
from miio.gateway import Gateway
from miio.heater import Heater
from miio.heater_miot import HeaterMiot
Expand Down
2 changes: 2 additions & 0 deletions miio/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from miio.miioprotocol import MiIOProtocol

from .cloud import cloud
from .devicefactory import factory
from .devtools import devtools

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,6 +63,7 @@ def discover(mdns, handshake, network, timeout):
cli.add_command(discover)
cli.add_command(cloud)
cli.add_command(devtools)
cli.add_command(factory)


def create_cli():
Expand Down
8 changes: 5 additions & 3 deletions miio/click_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import re
from functools import partial, wraps
from typing import Callable, Set, Type, Union
from typing import Any, Callable, ClassVar, Dict, List, Set, Type, Union

import click

Expand Down Expand Up @@ -110,6 +110,8 @@ def __init__(self, debug: int = 0, output: Callable = None):
class DeviceGroupMeta(type):

_device_classes: Set[Type] = set()
_supported_models: ClassVar[List[str]]
_mappings: ClassVar[Dict[str, Any]]

def __new__(mcs, name, bases, namespace):
commands = {}
Expand Down Expand Up @@ -146,9 +148,9 @@ def get_device_group(dcls):
return cls

@property
def supported_models(cls):
def supported_models(cls) -> List[str]:
"""Return list of supported models."""
return cls._mappings.keys() or cls._supported_models
return list(cls._mappings.keys()) or cls._supported_models


class DeviceGroup(click.MultiCommand):
Expand Down
8 changes: 8 additions & 0 deletions miio/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ class Device(metaclass=DeviceGroupMeta):
_mappings: Dict[str, Any] = {}
_supported_models: List[str] = []

def __init_subclass__(cls, **kwargs):
"""Overridden to register all integrations to the factory."""
super().__init_subclass__(**kwargs)

from .devicefactory import DeviceFactory

DeviceFactory.register(cls)

def __init__(
self,
ip: str = None,
Expand Down
109 changes: 109 additions & 0 deletions miio/devicefactory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import logging
from typing import Dict, List, Optional, Type

import click

from .device import Device
from .exceptions import DeviceException

_LOGGER = logging.getLogger(__name__)


class DeviceFactory:
"""A helper class to construct devices based on their info responses.
This class keeps list of supported integrations and models to allow creating
:class:`Device` instances without knowing anything except the host and the token.
:func:`create` is the main entry point when using this module. Example::
from miio import DeviceFactory
dev = DeviceFactory.create("127.0.0.1", 32*"0")
"""

_integration_classes: List[Type[Device]] = []
_supported_models: Dict[str, Type[Device]] = {}

@classmethod
def register(cls, integration_cls: Type[Device]):
"""Register class for to the registry."""
cls._integration_classes.append(integration_cls)
_LOGGER.debug("Registering %s", integration_cls.__name__)
for model in integration_cls.supported_models: # type: ignore
if model in cls._supported_models:
_LOGGER.debug(
"Got duplicate of %s for %s, previously registered by %s",
model,
integration_cls,
cls._supported_models[model],
)

_LOGGER.debug(" * %s => %s", model, integration_cls)
cls._supported_models[model] = integration_cls

@classmethod
def supported_models(cls) -> Dict[str, Type[Device]]:
"""Return a dictionary of models and their corresponding implementation
classes."""
return cls._supported_models

@classmethod
def integrations(cls) -> List[Type[Device]]:
"""Return the list of integration classes."""
return cls._integration_classes

@classmethod
def class_for_model(cls, model: str):
"""Return implementation class for the given model, if available."""
if model in cls._supported_models:
return cls._supported_models[model]

wildcard_models = {
m: impl for m, impl in cls._supported_models.items() if m.endswith("*")
}
for wildcard_model, impl in wildcard_models.items():
m = wildcard_model.rstrip("*")
if model.startswith(m):
_LOGGER.debug(
"Using %s for %s, please add it to supported models for %s",
wildcard_model,
model,
impl,
)
return impl

raise DeviceException("No implementation found for model %s" % model)

@classmethod
def create(self, host: str, token: str, model: Optional[str] = None) -> Device:
"""Return instance for the given host and token, with optional model override.
The optional model parameter can be used to override the model detection.
"""
if model is None:
dev: Device = Device(host, token)
info = dev.info()
model = info.model

return self.class_for_model(model)(host, token, model=model)


@click.group()
def factory():
"""Access to available integrations."""


@factory.command()
def integrations():
for integration in DeviceFactory.integrations():
click.echo(
f"* {integration} supports {len(integration.supported_models)} models"
)


@factory.command()
def models():
"""List supported models."""
for model in DeviceFactory.supported_models():
click.echo(f"* {model}")
12 changes: 1 addition & 11 deletions miio/integrations/light/yeelight/spec_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ def __init__(self):
self._parse_specs_yaml()

def _parse_specs_yaml(self):
generic_info = YeelightModelInfo(
"generic",
False,
{
YeelightSubLightType.Main: YeelightLampInfo(
ColorTempRange(1700, 6500), False
)
},
)
YeelightSpecHelper._models["generic"] = generic_info
# read the yaml file to populate the internal model cache
with open(os.path.dirname(__file__) + "/specs.yaml") as filedata:
models = yaml.safe_load(filedata)
Expand Down Expand Up @@ -82,5 +72,5 @@ def get_model_info(self, model) -> YeelightModelInfo:
"Unknown model %s, please open an issue and supply features for this light. Returning generic information.",
model,
)
return self._models["generic"]
return self._models["yeelink.light.*"]
return self._models[model]
4 changes: 4 additions & 0 deletions miio/integrations/light/yeelight/specs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,7 @@ yeelink.light.lamp22:
night_light: False
color_temp: [2700, 6500]
supports_color: True
yeelink.light.*:
night_light: False
color_temp: [1700, 6500]
supports_color: False
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_get_model_info():
def test_get_unknown_model_info():
spec_helper = YeelightSpecHelper()
model_info = spec_helper.get_model_info("notreal")
assert model_info.model == "generic"
assert model_info.model == "yeelink.light.*"
assert model_info.night_light is False
assert model_info.lamps[YeelightSubLightType.Main].color_temp == ColorTempRange(
1700, 6500
Expand Down
7 changes: 2 additions & 5 deletions miio/integrations/light/yeelight/yeelight.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ class Yeelight(Device):
which however requires enabling the developer mode on the bulbs.
"""

_supported_models: List[str] = []
_spec_helper = None
_spec_helper = YeelightSpecHelper()
_supported_models: List[str] = _spec_helper.supported_models

def __init__(
self,
Expand All @@ -267,9 +267,6 @@ def __init__(
model: str = None,
) -> None:
super().__init__(ip, token, start_id, debug, lazy_discover, model=model)
if Yeelight._spec_helper is None:
Yeelight._spec_helper = YeelightSpecHelper()
Yeelight._supported_models = Yeelight._spec_helper.supported_models

self._model_info = Yeelight._spec_helper.get_model_info(self.model)
self._light_type = YeelightSubLightType.Main
Expand Down
42 changes: 42 additions & 0 deletions miio/tests/test_devicefactory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from miio import Device, DeviceException, DeviceFactory, Gateway, MiotDevice

DEVICE_CLASSES = Device.__subclasses__() + MiotDevice.__subclasses__() # type: ignore
DEVICE_CLASSES.remove(MiotDevice)


def test_device_all_supported_models():
models = DeviceFactory.supported_models()
for model, impl in models.items():
assert isinstance(model, str)
assert issubclass(impl, Device)


@pytest.mark.parametrize("cls", DEVICE_CLASSES)
def test_device_class_for_model(cls):
"""Test that all supported models can be initialized using class_for_model."""

if cls == Gateway:
pytest.skip(
"Skipping Gateway as AirConditioningCompanion already implements lumi.acpartner.*"
)

for supp in cls.supported_models:
dev = DeviceFactory.class_for_model(supp)
assert issubclass(dev, cls)


def test_device_class_for_wildcard():
"""Test that wildcard matching works."""

class _DummyDevice(Device):
_supported_models = ["foo.bar.*"]

assert DeviceFactory.class_for_model("foo.bar.aaaa") == _DummyDevice


def test_device_class_for_model_unknown():
"""Test that unknown model raises an exception."""
with pytest.raises(DeviceException):
DeviceFactory.class_for_model("foo.foo.xyz")
2 changes: 1 addition & 1 deletion miio/tests/test_miotdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_mapping_structure(cls):

@pytest.mark.parametrize("cls", MIOT_DEVICES)
def test_supported_models(cls):
assert cls.supported_models == cls._mappings.keys()
assert cls.supported_models == list(cls._mappings.keys())

# make sure that that _supported_models is not defined
assert not cls._supported_models
Expand Down

0 comments on commit 8270db5

Please sign in to comment.