diff --git a/miio/push_server/server.py b/miio/push_server/server.py index 9a21cef31..ddd906eab 100644 --- a/miio/push_server/server.py +++ b/miio/push_server/server.py @@ -17,6 +17,7 @@ FAKE_DEVICE_MODEL = "chuangmi.plug.v3" PushServerCallback = Callable[[str, str, str], None] +MethodDict = Dict[str, Union[Dict, Callable]] def calculated_token_enc(token): @@ -66,7 +67,7 @@ def __init__(self, device_ip=None): self._listen_couroutine = None self._registered_devices = {} - self._methods = {} + self._methods: MethodDict = {} self._event_id = 1000000 @@ -325,6 +326,6 @@ def server_model(self): return self._server_model @property - def methods(self): + def methods(self) -> MethodDict: """Return a dict of implemented methods.""" return self._methods diff --git a/miio/push_server/serverprotocol.py b/miio/push_server/serverprotocol.py index 2bc43597d..6e2459e08 100644 --- a/miio/push_server/serverprotocol.py +++ b/miio/push_server/serverprotocol.py @@ -11,6 +11,10 @@ "21310020ffffffffffffffffffffffffffffffffffffffffffffffffffffffff" ) +ERR_INVALID = -1 +ERR_UNSUPPORTED = -2 +ERR_METHOD_EXEC_FAILED = -3 + class ServerProtocol: """Handle responding to UDP packets.""" @@ -73,11 +77,11 @@ def send_response(self, host, port, msg_id, token, payload=None): if payload is None: payload = {} - result = {**payload, "id": msg_id} - msg = self._create_message(result, token, device_id=self.server.server_id) + data = {**payload, "id": msg_id} + msg = self._create_message(data, token, device_id=self.server.server_id) self.transport.sendto(msg, (host, port)) - _LOGGER.debug(">> %s:%s: %s", host, port, result) + _LOGGER.debug(">> %s:%s: %s", host, port, data) def send_error(self, host, port, msg_id, token, code, message): """Send error message with given code and message to the client.""" @@ -121,19 +125,36 @@ def _handle_datagram_from_client(self, host: str, port: int, data): msg_value, ) + if "method" not in msg_value: + return self.send_error( + host, port, msg_id, token, ERR_INVALID, "missing method" + ) + methods = self.server.methods if msg_value["method"] not in methods: - return self.send_error(host, port, msg_id, token, -1, "unsupported method") + return self.send_error( + host, port, msg_id, token, ERR_UNSUPPORTED, "unsupported method" + ) + _LOGGER.debug("Got method call: %s", msg_value["method"]) method = methods[msg_value["method"]] if callable(method): try: response = method(msg_value) except Exception as ex: - return self.send_error(host, port, msg_id, token, -1, str(ex)) + _LOGGER.exception(ex) + return self.send_error( + host, + port, + msg_id, + token, + ERR_METHOD_EXEC_FAILED, + f"Exception {type(ex)}: {ex}", + ) else: response = method + _LOGGER.debug("Responding %s with %s", msg_id, response) return self.send_response(host, port, msg_id, token, payload=response) def datagram_received(self, data, addr): diff --git a/miio/push_server/test_serverprotocol.py b/miio/push_server/test_serverprotocol.py index 37ce3bd63..42fa18132 100644 --- a/miio/push_server/test_serverprotocol.py +++ b/miio/push_server/test_serverprotocol.py @@ -2,7 +2,12 @@ from miio import Message -from .serverprotocol import ServerProtocol +from .serverprotocol import ( + ERR_INVALID, + ERR_METHOD_EXEC_FAILED, + ERR_UNSUPPORTED, + ServerProtocol, +) HOST = "127.0.0.1" PORT = 1234 @@ -108,15 +113,44 @@ def test_datagram_with_known_method(protocol: ServerProtocol, mocker): assert cargs["payload"] == response_payload -def test_datagram_with_unknown_method(protocol: ServerProtocol, mocker): - """Test that regular client messages are handled properly.""" +@pytest.mark.parametrize( + "method,err_code", [("unknown_method", ERR_UNSUPPORTED), (None, ERR_INVALID)] +) +def test_datagram_with_unknown_method( + method, err_code, protocol: ServerProtocol, mocker +): + """Test that invalid payloads are erroring out correctly.""" protocol.send_error = mocker.Mock() # type: ignore[assignment] protocol.server.methods = {} - msg = protocol._create_message({"id": 1, "method": "miIO.info"}, DUMMY_TOKEN, 1234) + data = {"id": 1} + + if method is not None: + data["method"] = method + + msg = protocol._create_message(data, DUMMY_TOKEN, 1234) + protocol._handle_datagram_from_client(HOST, PORT, msg) + + protocol.send_error.assert_called() # type: ignore + cargs = protocol.send_error.call_args[0] # type: ignore + assert cargs[4] == err_code + + +def test_datagram_with_exception_raising(protocol: ServerProtocol, mocker): + """Test that exception raising callbacks are .""" + protocol.send_error = mocker.Mock() # type: ignore[assignment] + + def _raise(*args, **kwargs): + raise Exception("error message") + + protocol.server.methods = {"raise": _raise} + + data = {"id": 1, "method": "raise"} + + msg = protocol._create_message(data, DUMMY_TOKEN, 1234) protocol._handle_datagram_from_client(HOST, PORT, msg) protocol.send_error.assert_called() # type: ignore cargs = protocol.send_error.call_args[0] # type: ignore - assert cargs[4] == -1 - assert cargs[5] == "unsupported method" + assert cargs[4] == ERR_METHOD_EXEC_FAILED + assert "error message" in cargs[5]