Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pushserver more generic #1531

Merged
merged 3 commits into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions miio/push_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import socket
from json import dumps
from random import randint
from typing import Callable, Optional
from typing import Callable, Dict, Optional, Union

from ..device import Device
from ..protocol import Utils
Expand Down Expand Up @@ -53,7 +53,7 @@ class PushServer:
await push_server.stop()
"""

def __init__(self, device_ip):
def __init__(self, device_ip=None):
"""Initialize the class."""
self._device_ip = device_ip

Expand All @@ -66,6 +66,8 @@ def __init__(self, device_ip):
self._listen_couroutine = None
self._registered_devices = {}

self._methods = {}

self._event_id = 1000000

async def start(self):
Expand All @@ -76,7 +78,9 @@ async def start(self):

self._loop = asyncio.get_event_loop()

_, self._listen_couroutine = await self._create_udp_server()
transport, self._listen_couroutine = await self._create_udp_server()

return transport, self._listen_couroutine

async def stop(self):
"""Stop Miio push server."""
Expand All @@ -90,6 +94,13 @@ async def stop(self):
self._listen_couroutine = None
self._loop = None

def add_method(self, name: str, response: Union[Dict, Callable]):
"""Add a method to server.

The response can be either a callable or a dictionary to send back as response.
"""
self._methods[name] = response

def register_miio_device(self, device: Device, callback: PushServerCallback):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a check here:

if self._device_ip is None:
   warning
   return

To make sure this is a proper server that can work with real devices

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I'm following what you mean, the device_ip is checked right below?

"""Register a miio device to this push server."""
if device.ip is None:
Expand Down Expand Up @@ -208,7 +219,8 @@ async def _get_server_ip(self):

async def _create_udp_server(self):
"""Create the UDP socket and protocol."""
self._server_ip = await self._get_server_ip()
if self._device_ip is not None:
self._server_ip = await self._get_server_ip()

# Create a fresh socket that will be used for the push server
udp_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
Expand Down Expand Up @@ -311,3 +323,8 @@ def server_id(self):
def server_model(self):
"""Return the model of the fake device beeing emulated."""
return self._server_model

@property
def methods(self):
"""Return a dict of implemented methods."""
return self._methods
110 changes: 77 additions & 33 deletions miio/push_server/serverprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,57 +52,101 @@ def send_ping_ACK(self, host, port):
self.transport.sendto(m, (host, port))
_LOGGER.debug("%s:%s<=ACK(server_id=%s)", host, port, self.server.server_id)

def send_msg_OK(self, host, port, msg_id, token):
# This result means OK, but some methods return ['ok'] instead of 0
# might be necessary to use different results for different methods
result = {"result": 0, "id": msg_id}
def _create_message(self, data, token, device_id):
"""Create a message to be sent to the client."""
header = {
"length": 0,
"unknown": 0,
"device_id": self.server.server_id,
"device_id": device_id,
"ts": datetime.datetime.now(),
}
msg = {
"data": {"value": result},
"data": {"value": data},
"header": {"value": header},
"checksum": 0,
}
response = Message.build(msg, token=token)
self.transport.sendto(response, (host, port))

return response

def send_response(self, host, port, msg_id, token, payload=None):
if payload is None:
payload = {}

result = {**payload, "id": msg_id}
rytilahti marked this conversation as resolved.
Show resolved Hide resolved
msg = self._create_message(result, token, device_id=self.server.server_id)

self.transport.sendto(msg, (host, port))
_LOGGER.debug(">> %s:%s: %s", host, port, result)

def datagram_received(self, data, addr):
"""Handle received messages."""
try:
(host, port) = addr
if data == HELO_BYTES:
self.send_ping_ACK(host, port)
return
def send_error(self, host, port, msg_id, token, code, message):
"""Send error message with given code and message to the client."""
return self.send_response(
host, port, msg_id, token, {"error": {"code": code, "error": message}}
)

def _handle_datagram_from_registered_device(self, host, port, data):
"""Handle requests from registered eventing devices."""
token = self.server._registered_devices[host]["token"]
callback = self.server._registered_devices[host]["callback"]

msg = Message.parse(data, token=token)
msg_value = msg.data.value
msg_id = msg_value["id"]
_LOGGER.debug("<< %s:%s: %s", host, port, msg_value)

if host not in self.server._registered_devices:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to have this error back when the host is not registered as a miio_device or a client.
See comment below.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now checked inside datagram_received, for registered devices. But you are right, we are responding to all non-registered devices as if they were real clients..

_LOGGER.warning(
"Datagram received from unknown device (%s:%s)",
host,
port,
)
return
# Send OK
# This result means OK, but some methods return ['ok'] instead of 0
# might be necessary to use different results for different methods
payload = {"result": 0}
self.send_response(host, port, msg_id, token, payload=payload)

# Parse message
action, device_call_id = msg_value["method"].rsplit(":", 1)
source_device_id = device_call_id.replace("_", ".")

callback(source_device_id, action, msg_value.get("params"))

def _handle_datagram_from_client(self, host: str, port: int, data):
"""Handle datagram from a regular client."""
token = bytes.fromhex(32 * "0") # TODO: make token configurable?
msg = Message.parse(data, token=token)
msg_value = msg.data.value
msg_id = msg_value["id"]

_LOGGER.debug(
"Received datagram #%s from regular client: %s: %s",
msg_id,
host,
msg_value,
)

token = self.server._registered_devices[host]["token"]
callback = self.server._registered_devices[host]["callback"]
methods = self.server.methods
if msg_value["method"] not in methods:
return self.send_error(host, port, msg_id, token, -1, "unsupported method")

msg = Message.parse(data, token=token)
msg_value = msg.data.value
msg_id = msg_value["id"]
_LOGGER.debug("<< %s:%s: %s", host, port, msg_value)
method = methods[msg_value["method"]]
if callable(method):
try:
response = method(msg_value)
except Exception as ex:
return self.send_error(host, port, msg_id, token, -1, str(ex))
else:
response = method

# Send OK
self.send_msg_OK(host, port, msg_id, token)
return self.send_response(host, port, msg_id, token, payload=response)

# Parse message
action, device_call_id = msg_value["method"].rsplit(":", 1)
source_device_id = device_call_id.replace("_", ".")
def datagram_received(self, data, addr):
"""Handle received messages."""
try:
(host, port) = addr
if data == HELO_BYTES:
return self.send_ping_ACK(host, port)

callback(source_device_id, action, msg_value.get("params"))
if host in self.server._registered_devices:
return self._handle_datagram_from_registered_device(host, port, data)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am slightly worried that we might exedently send incorrect messages to a real device if that device is for instance not properly disconnected and the server is restarted or other edge cases.
Maybe we schould make a register_client function and check against that instead of just handeling everything which is not registerd as a miio_device as a client.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The incorrect messages would be encrypted with a hardcoded token that should be different from a real device, so sending a message to the server with some other will simply cause the device not to respond. That is why I don't think it's a non-issue.

return self._handle_datagram_from_client(host, port, data)

except Exception:
_LOGGER.exception(
Expand Down
122 changes: 122 additions & 0 deletions miio/push_server/test_serverprotocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest

from miio import Message

from .serverprotocol import ServerProtocol

HOST = "127.0.0.1"
PORT = 1234
SERVER_ID = 4141
DUMMY_TOKEN = bytes.fromhex("0" * 32)


@pytest.fixture
def protocol(mocker, event_loop) -> ServerProtocol:
server = mocker.Mock()

# Mock server id
type(server).server_id = mocker.PropertyMock(return_value=SERVER_ID)
socket = mocker.Mock()

proto = ServerProtocol(event_loop, socket, server)
proto.transport = mocker.Mock()

yield proto


def test_send_ping_ack(protocol: ServerProtocol, mocker):
"""Test that ping acks are send as expected."""
protocol.send_ping_ACK(HOST, PORT)
protocol.transport.sendto.assert_called()

cargs = protocol.transport.sendto.call_args[0]

m = Message.parse(cargs[0])
assert int.from_bytes(m.header.value.device_id, "big") == SERVER_ID
assert m.data.length == 0

assert cargs[1][0] == HOST
assert cargs[1][1] == PORT


def test_send_response(protocol: ServerProtocol):
"""Test that send_response sends valid messages."""
payload = {"foo": 1}
protocol.send_response(HOST, PORT, 1, DUMMY_TOKEN, payload)
protocol.transport.sendto.assert_called()

cargs = protocol.transport.sendto.call_args[0]
m = Message.parse(cargs[0], token=DUMMY_TOKEN)
payload = m.data.value
assert payload["id"] == 1
assert payload["foo"] == 1


def test_send_error(protocol: ServerProtocol, mocker):
"""Test that error payloads are created correctly."""
ERR_MSG = "example error"
ERR_CODE = -1
protocol.send_error(HOST, PORT, 1, DUMMY_TOKEN, code=ERR_CODE, message=ERR_MSG)
protocol.send_response = mocker.Mock() # type: ignore[assignment]
protocol.transport.sendto.assert_called()

cargs = protocol.transport.sendto.call_args[0]
m = Message.parse(cargs[0], token=DUMMY_TOKEN)
payload = m.data.value

assert "error" in payload
assert payload["error"]["code"] == ERR_CODE
assert payload["error"]["error"] == ERR_MSG


def test__handle_datagram_from_registered_device(protocol: ServerProtocol, mocker):
"""Test that events from registered devices are handled correctly."""
protocol.server._registered_devices = {HOST: {}}
protocol.server._registered_devices[HOST]["token"] = DUMMY_TOKEN
dummy_callback = mocker.Mock()
protocol.server._registered_devices[HOST]["callback"] = dummy_callback

PARAMS = {"test_param": 1}
payload = {"id": 1, "method": "action:source_device", "params": PARAMS}
msg_from_device = protocol._create_message(payload, DUMMY_TOKEN, 4242)

protocol._handle_datagram_from_registered_device(HOST, PORT, msg_from_device)

# Assert that a response is sent back
protocol.transport.sendto.assert_called()

# Assert that the callback is called
dummy_callback.assert_called()
cargs = dummy_callback.call_args[0]
assert cargs[2] == PARAMS
assert cargs[0] == "source.device"
assert cargs[1] == "action"


def test_datagram_with_known_method(protocol: ServerProtocol, mocker):
"""Test that regular client messages are handled properly."""
protocol.send_response = mocker.Mock() # type: ignore[assignment]

response_payload = {"result": "info response"}
protocol.server.methods = {"miIO.info": response_payload}

msg = protocol._create_message({"id": 1, "method": "miIO.info"}, DUMMY_TOKEN, 1234)
protocol._handle_datagram_from_client(HOST, PORT, msg)

protocol.send_response.assert_called() # type: ignore
cargs = protocol.send_response.call_args[1] # type: ignore
assert cargs["payload"] == response_payload


def test_datagram_with_unknown_method(protocol: ServerProtocol, mocker):
"""Test that regular client messages are handled properly."""
protocol.send_error = mocker.Mock() # type: ignore[assignment]
protocol.server.methods = {}

msg = protocol._create_message({"id": 1, "method": "miIO.info"}, DUMMY_TOKEN, 1234)
protocol._handle_datagram_from_client(HOST, PORT, msg)

protocol.send_error.assert_called() # type: ignore
cargs = protocol.send_error.call_args[0] # type: ignore
assert cargs[4] == -1
assert cargs[5] == "unsupported method"
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ docs = ["sphinx", "sphinx_click", "sphinxcontrib-apidoc", "sphinx_rtd_theme"]
pytest = ">=6.2.5"
pytest-cov = "^2"
pytest-mock = "^3"
pytest-asyncio = "*"
voluptuous = "^0"
pre-commit = "^2"
doc8 = "^0"
Expand Down