diff --git a/miio/push_server/server.py b/miio/push_server/server.py index 4e69014f5..9a21cef31 100644 --- a/miio/push_server/server.py +++ b/miio/push_server/server.py @@ -3,7 +3,7 @@ import socket from json import dumps from random import randint -from typing import Callable, Optional +from typing import Callable, Dict, Optional, Union from ..device import Device from ..protocol import Utils @@ -53,7 +53,7 @@ class PushServer: await push_server.stop() """ - def __init__(self, device_ip): + def __init__(self, device_ip=None): """Initialize the class.""" self._device_ip = device_ip @@ -66,6 +66,8 @@ def __init__(self, device_ip): self._listen_couroutine = None self._registered_devices = {} + self._methods = {} + self._event_id = 1000000 async def start(self): @@ -76,7 +78,9 @@ async def start(self): self._loop = asyncio.get_event_loop() - _, self._listen_couroutine = await self._create_udp_server() + transport, self._listen_couroutine = await self._create_udp_server() + + return transport, self._listen_couroutine async def stop(self): """Stop Miio push server.""" @@ -90,6 +94,13 @@ async def stop(self): self._listen_couroutine = None self._loop = None + def add_method(self, name: str, response: Union[Dict, Callable]): + """Add a method to server. + + The response can be either a callable or a dictionary to send back as response. + """ + self._methods[name] = response + def register_miio_device(self, device: Device, callback: PushServerCallback): """Register a miio device to this push server.""" if device.ip is None: @@ -208,7 +219,8 @@ async def _get_server_ip(self): async def _create_udp_server(self): """Create the UDP socket and protocol.""" - self._server_ip = await self._get_server_ip() + if self._device_ip is not None: + self._server_ip = await self._get_server_ip() # Create a fresh socket that will be used for the push server udp_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) @@ -311,3 +323,8 @@ def server_id(self): def server_model(self): """Return the model of the fake device beeing emulated.""" return self._server_model + + @property + def methods(self): + """Return a dict of implemented methods.""" + return self._methods diff --git a/miio/push_server/serverprotocol.py b/miio/push_server/serverprotocol.py index eefc85629..2bc43597d 100644 --- a/miio/push_server/serverprotocol.py +++ b/miio/push_server/serverprotocol.py @@ -52,57 +52,101 @@ def send_ping_ACK(self, host, port): self.transport.sendto(m, (host, port)) _LOGGER.debug("%s:%s<=ACK(server_id=%s)", host, port, self.server.server_id) - def send_msg_OK(self, host, port, msg_id, token): - # This result means OK, but some methods return ['ok'] instead of 0 - # might be necessary to use different results for different methods - result = {"result": 0, "id": msg_id} + def _create_message(self, data, token, device_id): + """Create a message to be sent to the client.""" header = { "length": 0, "unknown": 0, - "device_id": self.server.server_id, + "device_id": device_id, "ts": datetime.datetime.now(), } msg = { - "data": {"value": result}, + "data": {"value": data}, "header": {"value": header}, "checksum": 0, } response = Message.build(msg, token=token) - self.transport.sendto(response, (host, port)) + + return response + + def send_response(self, host, port, msg_id, token, payload=None): + if payload is None: + payload = {} + + result = {**payload, "id": msg_id} + msg = self._create_message(result, token, device_id=self.server.server_id) + + self.transport.sendto(msg, (host, port)) _LOGGER.debug(">> %s:%s: %s", host, port, result) - def datagram_received(self, data, addr): - """Handle received messages.""" - try: - (host, port) = addr - if data == HELO_BYTES: - self.send_ping_ACK(host, port) - return + def send_error(self, host, port, msg_id, token, code, message): + """Send error message with given code and message to the client.""" + return self.send_response( + host, port, msg_id, token, {"error": {"code": code, "error": message}} + ) + + def _handle_datagram_from_registered_device(self, host, port, data): + """Handle requests from registered eventing devices.""" + token = self.server._registered_devices[host]["token"] + callback = self.server._registered_devices[host]["callback"] + + msg = Message.parse(data, token=token) + msg_value = msg.data.value + msg_id = msg_value["id"] + _LOGGER.debug("<< %s:%s: %s", host, port, msg_value) - if host not in self.server._registered_devices: - _LOGGER.warning( - "Datagram received from unknown device (%s:%s)", - host, - port, - ) - return + # Send OK + # This result means OK, but some methods return ['ok'] instead of 0 + # might be necessary to use different results for different methods + payload = {"result": 0} + self.send_response(host, port, msg_id, token, payload=payload) + + # Parse message + action, device_call_id = msg_value["method"].rsplit(":", 1) + source_device_id = device_call_id.replace("_", ".") + + callback(source_device_id, action, msg_value.get("params")) + + def _handle_datagram_from_client(self, host: str, port: int, data): + """Handle datagram from a regular client.""" + token = bytes.fromhex(32 * "0") # TODO: make token configurable? + msg = Message.parse(data, token=token) + msg_value = msg.data.value + msg_id = msg_value["id"] + + _LOGGER.debug( + "Received datagram #%s from regular client: %s: %s", + msg_id, + host, + msg_value, + ) - token = self.server._registered_devices[host]["token"] - callback = self.server._registered_devices[host]["callback"] + methods = self.server.methods + if msg_value["method"] not in methods: + return self.send_error(host, port, msg_id, token, -1, "unsupported method") - msg = Message.parse(data, token=token) - msg_value = msg.data.value - msg_id = msg_value["id"] - _LOGGER.debug("<< %s:%s: %s", host, port, msg_value) + method = methods[msg_value["method"]] + if callable(method): + try: + response = method(msg_value) + except Exception as ex: + return self.send_error(host, port, msg_id, token, -1, str(ex)) + else: + response = method - # Send OK - self.send_msg_OK(host, port, msg_id, token) + return self.send_response(host, port, msg_id, token, payload=response) - # Parse message - action, device_call_id = msg_value["method"].rsplit(":", 1) - source_device_id = device_call_id.replace("_", ".") + def datagram_received(self, data, addr): + """Handle received messages.""" + try: + (host, port) = addr + if data == HELO_BYTES: + return self.send_ping_ACK(host, port) - callback(source_device_id, action, msg_value.get("params")) + if host in self.server._registered_devices: + return self._handle_datagram_from_registered_device(host, port, data) + else: + return self._handle_datagram_from_client(host, port, data) except Exception: _LOGGER.exception( diff --git a/miio/push_server/test_serverprotocol.py b/miio/push_server/test_serverprotocol.py new file mode 100644 index 000000000..37ce3bd63 --- /dev/null +++ b/miio/push_server/test_serverprotocol.py @@ -0,0 +1,122 @@ +import pytest + +from miio import Message + +from .serverprotocol import ServerProtocol + +HOST = "127.0.0.1" +PORT = 1234 +SERVER_ID = 4141 +DUMMY_TOKEN = bytes.fromhex("0" * 32) + + +@pytest.fixture +def protocol(mocker, event_loop) -> ServerProtocol: + server = mocker.Mock() + + # Mock server id + type(server).server_id = mocker.PropertyMock(return_value=SERVER_ID) + socket = mocker.Mock() + + proto = ServerProtocol(event_loop, socket, server) + proto.transport = mocker.Mock() + + yield proto + + +def test_send_ping_ack(protocol: ServerProtocol, mocker): + """Test that ping acks are send as expected.""" + protocol.send_ping_ACK(HOST, PORT) + protocol.transport.sendto.assert_called() + + cargs = protocol.transport.sendto.call_args[0] + + m = Message.parse(cargs[0]) + assert int.from_bytes(m.header.value.device_id, "big") == SERVER_ID + assert m.data.length == 0 + + assert cargs[1][0] == HOST + assert cargs[1][1] == PORT + + +def test_send_response(protocol: ServerProtocol): + """Test that send_response sends valid messages.""" + payload = {"foo": 1} + protocol.send_response(HOST, PORT, 1, DUMMY_TOKEN, payload) + protocol.transport.sendto.assert_called() + + cargs = protocol.transport.sendto.call_args[0] + m = Message.parse(cargs[0], token=DUMMY_TOKEN) + payload = m.data.value + assert payload["id"] == 1 + assert payload["foo"] == 1 + + +def test_send_error(protocol: ServerProtocol, mocker): + """Test that error payloads are created correctly.""" + ERR_MSG = "example error" + ERR_CODE = -1 + protocol.send_error(HOST, PORT, 1, DUMMY_TOKEN, code=ERR_CODE, message=ERR_MSG) + protocol.send_response = mocker.Mock() # type: ignore[assignment] + protocol.transport.sendto.assert_called() + + cargs = protocol.transport.sendto.call_args[0] + m = Message.parse(cargs[0], token=DUMMY_TOKEN) + payload = m.data.value + + assert "error" in payload + assert payload["error"]["code"] == ERR_CODE + assert payload["error"]["error"] == ERR_MSG + + +def test__handle_datagram_from_registered_device(protocol: ServerProtocol, mocker): + """Test that events from registered devices are handled correctly.""" + protocol.server._registered_devices = {HOST: {}} + protocol.server._registered_devices[HOST]["token"] = DUMMY_TOKEN + dummy_callback = mocker.Mock() + protocol.server._registered_devices[HOST]["callback"] = dummy_callback + + PARAMS = {"test_param": 1} + payload = {"id": 1, "method": "action:source_device", "params": PARAMS} + msg_from_device = protocol._create_message(payload, DUMMY_TOKEN, 4242) + + protocol._handle_datagram_from_registered_device(HOST, PORT, msg_from_device) + + # Assert that a response is sent back + protocol.transport.sendto.assert_called() + + # Assert that the callback is called + dummy_callback.assert_called() + cargs = dummy_callback.call_args[0] + assert cargs[2] == PARAMS + assert cargs[0] == "source.device" + assert cargs[1] == "action" + + +def test_datagram_with_known_method(protocol: ServerProtocol, mocker): + """Test that regular client messages are handled properly.""" + protocol.send_response = mocker.Mock() # type: ignore[assignment] + + response_payload = {"result": "info response"} + protocol.server.methods = {"miIO.info": response_payload} + + msg = protocol._create_message({"id": 1, "method": "miIO.info"}, DUMMY_TOKEN, 1234) + protocol._handle_datagram_from_client(HOST, PORT, msg) + + protocol.send_response.assert_called() # type: ignore + cargs = protocol.send_response.call_args[1] # type: ignore + assert cargs["payload"] == response_payload + + +def test_datagram_with_unknown_method(protocol: ServerProtocol, mocker): + """Test that regular client messages are handled properly.""" + protocol.send_error = mocker.Mock() # type: ignore[assignment] + protocol.server.methods = {} + + msg = protocol._create_message({"id": 1, "method": "miIO.info"}, DUMMY_TOKEN, 1234) + protocol._handle_datagram_from_client(HOST, PORT, msg) + + protocol.send_error.assert_called() # type: ignore + cargs = protocol.send_error.call_args[0] # type: ignore + assert cargs[4] == -1 + assert cargs[5] == "unsupported method" diff --git a/poetry.lock b/poetry.lock index 25d7c3b20..a6cdc0dd3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -522,6 +522,20 @@ tomli = ">=1.0.0" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.19.0" +description = "Pytest support for asyncio" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +pytest = ">=6.1.0" + +[package.extras] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-cov" version = "2.12.1" @@ -967,7 +981,7 @@ docs = ["sphinx", "sphinx_click", "sphinxcontrib-apidoc", "sphinx_rtd_theme"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "c9fcfce783eee7f667e48c0e01d96c4da3a999fd5a7d5fbf88b16960234d9e57" +content-hash = "0367a3767c8d8b6d3b82b49cf730920109e9a59945ffab7ea5b7535e7642d9c8" [metadata.files] alabaster = [ @@ -1422,6 +1436,10 @@ pytest = [ {file = "pytest-7.1.3-py3-none-any.whl", hash = "sha256:1377bda3466d70b55e3f5cecfa55bb7cfcf219c7964629b967c37cf0bda818b7"}, {file = "pytest-7.1.3.tar.gz", hash = "sha256:4f365fec2dff9c1162f834d9f18af1ba13062db0c708bf7b946f8a5c76180c39"}, ] +pytest-asyncio = [ + {file = "pytest-asyncio-0.19.0.tar.gz", hash = "sha256:ac4ebf3b6207259750bc32f4c1d8fcd7e79739edbc67ad0c58dd150b1d072fed"}, + {file = "pytest_asyncio-0.19.0-py3-none-any.whl", hash = "sha256:7a97e37cfe1ed296e2e84941384bdd37c376453912d397ed39293e0916f521fa"}, +] pytest-cov = [ {file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"}, {file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"}, diff --git a/pyproject.toml b/pyproject.toml index 9b8abf7e5..a0699161a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ docs = ["sphinx", "sphinx_click", "sphinxcontrib-apidoc", "sphinx_rtd_theme"] pytest = ">=6.2.5" pytest-cov = "^2" pytest-mock = "^3" +pytest-asyncio = "*" voluptuous = "^0" pre-commit = "^2" doc8 = "^0"