From 186f325741166185a55088f7dbdfc084f5068ed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zbigniew=20Pie=C5=9Blak?= Date: Sat, 19 Oct 2024 21:53:31 +0200 Subject: [PATCH] Initialize mqtt client inside App.call Refactor mqtt_client Move subclassing of paho.mqtt.client into direct callback assign for better code visibility. Extend Client API to move some logic from App. Update tests --- README.md | 2 +- mobilus_client/app.py | 42 +++--- mobilus_client/client.py | 123 ++++++++++++++++ mobilus_client/mqtt_client.py | 91 ------------ tests/test_app.py | 55 +++---- tests/{test_mqtt_client.py => test_client.py} | 135 ++++++++++-------- 6 files changed, 237 insertions(+), 211 deletions(-) create mode 100644 mobilus_client/client.py delete mode 100644 mobilus_client/mqtt_client.py rename tests/{test_mqtt_client.py => test_client.py} (62%) diff --git a/README.md b/README.md index 07379bd..f00bc05 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ The tests can be run with the following command: To generate a coverage report: coverage run -m unittest discover - coverage report + coverage report -m To regenerate proto files diff --git a/mobilus_client/app.py b/mobilus_client/app.py index 210eb08..b6ba531 100644 --- a/mobilus_client/app.py +++ b/mobilus_client/app.py @@ -2,9 +2,9 @@ import secrets import socket +from mobilus_client.client import Client from mobilus_client.config import Config from mobilus_client.messages.serializer import MessageSerializer -from mobilus_client.mqtt_client import MqttClient from mobilus_client.registries.key import KeyRegistry from mobilus_client.registries.message import MessageRegistry @@ -14,50 +14,44 @@ class App: def __init__(self, config: Config) -> None: self.config = config - self.message_registry = MessageRegistry() - self.key_registry = KeyRegistry(config.user_key) - self.client = MqttClient( - client_id=secrets.token_hex(6).upper(), - transport=config.gateway_protocol, - userdata={ - "config": config, - "key_registry": self.key_registry, - "message_registry": self.message_registry, - }, - ) def call(self, commands: list[tuple[str, dict[str, str]]]) -> str: if not commands: return self._empty_response() - try: - # Connect to the MQTT broker and start the loop - self.client.connect(self.config.gateway_host, self.config.gateway_port, 60) - self.client.loop_start() + # Initialize client and registries + key_registry = KeyRegistry(self.config.user_key) + message_registry = MessageRegistry() - # Wait for the client to authenticate - self.client.authenticated_event.wait(timeout=self.config.auth_timeout_period) + client = Client( + client_id=secrets.token_hex(6).upper(), + config=self.config, + key_registry=key_registry, + message_registry=message_registry, + ) - if not self.client.authenticated_event.is_set(): - logger.error("Failed to authenticate with the gateway host") + try: + # Connect to the MQTT broker and authenticate + if not client.connect_and_authenticate(): return self._empty_response() + # Execute the provided commands for command, params in commands: - self.client.send_request(command, **params) + client.send_request(command, **params) # Wait for the completion event to be triggered - self.client.completed_event.wait(timeout=self.config.timeout_period) + client.completed_event.wait(timeout=self.config.timeout_period) except socket.gaierror: logger.error("Failed to connect to the gateway host") except TimeoutError: logger.error("Timeout occurred") finally: - self.client.disconnect() + client.terminate() # Return serialized responses from the message registry return MessageSerializer.serialize_list_to_json( - self.message_registry.get_responses(), + message_registry.get_responses(), ) def _empty_response(self) -> str: diff --git a/mobilus_client/client.py b/mobilus_client/client.py new file mode 100644 index 0000000..8899780 --- /dev/null +++ b/mobilus_client/client.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING, Any + +import paho.mqtt.client as mqtt + +from mobilus_client.messages.encryptor import MessageEncryptor +from mobilus_client.messages.factory import MessageFactory +from mobilus_client.messages.status import MessageStatus +from mobilus_client.messages.validator import MessageValidator +from mobilus_client.proto import LoginRequest, LoginResponse + +if TYPE_CHECKING: + from mobilus_client.config import Config + from mobilus_client.registries.key import KeyRegistry + from mobilus_client.registries.message import MessageRegistry + +logger = logging.getLogger(__name__) + + +class Client: + def __init__( + self, client_id: str, config: Config, key_registry: KeyRegistry, message_registry: MessageRegistry) -> None: + self.config = config + self.client_id = client_id + self.key_registry = key_registry + self.message_registry = message_registry + self.authenticated_event = threading.Event() + self.completed_event = threading.Event() + + self.mqtt_client = mqtt.Client(client_id=self.client_id, transport=config.gateway_protocol) + self._configure_client() + + def connect_and_authenticate(self) -> bool: + self.mqtt_client.connect(self.config.gateway_host, self.config.gateway_port) + self.mqtt_client.loop_start() + + # Wait for the client to authenticate + self.authenticated_event.wait(timeout=self.config.auth_timeout_period) + + if not self.authenticated_event.is_set(): + logger.error("Failed to authenticate with the gateway host") + return False + + return True + + def send_request(self, command: str, **params: str | bytes | int | None) -> None: + if not self.mqtt_client.is_connected(): + logger.error("Sending request - %s failed. Client is not connected.", command) + return + + message = MessageFactory.create_message(command, **params) + status = MessageValidator.validate(message) + + if status != MessageStatus.SUCCESS or message is None: + logger.error("Command - %s returned an error - %s", command, status.name) + self.terminate() + return + + if not isinstance(message, LoginRequest): + self.message_registry.register_request(message) + + encrypted_message = MessageEncryptor.encrypt( + message, + self.client_id, + self.key_registry, + ) + + self.mqtt_client.publish("module", encrypted_message) + + def terminate(self) -> None: + self.mqtt_client.disconnect() + self.mqtt_client.loop_stop() + + def on_disconnect_callback(self, _client: mqtt.Client, _userdata: None, reason_code: int) -> None: + logger.info("Disconnected with result code - %s", reason_code) + + def on_connect_callback( + self, _client: mqtt.Client, _userdata: None, _flags: dict[str, Any], _reason_code: int) -> None: + self.mqtt_client.subscribe([ + (self.client_id, 0), + ("clients", 0), + ]) + + def on_subscribe_callback(self, _client: mqtt.Client, _userdata: None, _mid: int, _granted_qos: tuple[int]) -> None: + self.send_request( + "login", + login=self.config.user_login, + password=self.config.user_key, + ) + + def on_message_callback(self, _client: mqtt.Client, _userdata: None, mqtt_message: mqtt.MQTTMessage) -> None: + logger.info("Received message on topic - %s", mqtt_message.topic) + + message = MessageEncryptor.decrypt(mqtt_message.payload, self.key_registry) + logger.info("Decrypted message - %s", type(message).__name__) + + status = MessageValidator.validate(message) + + if status != MessageStatus.SUCCESS or message is None: + logger.error("Message - %s returned an error - %s", type(message).__name__, status.name) + self.terminate() + return + + logger.info("Message - %s validated successfully", type(message).__name__) + + if isinstance(message, LoginResponse): + self.key_registry.register_keys(message) + self.authenticated_event.set() + else: + self.message_registry.register_response(message) + + if self.message_registry.all_responses_received(): + self.completed_event.set() + + def _configure_client(self) -> None: + self.mqtt_client.enable_logger(logger) + self.mqtt_client.on_connect = self.on_connect_callback + self.mqtt_client.on_disconnect = self.on_disconnect_callback + self.mqtt_client.on_subscribe = self.on_subscribe_callback + self.mqtt_client.on_message = self.on_message_callback diff --git a/mobilus_client/mqtt_client.py b/mobilus_client/mqtt_client.py deleted file mode 100644 index 731b62a..0000000 --- a/mobilus_client/mqtt_client.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -import logging -import threading -from typing import Any, cast - -import paho.mqtt.client as mqtt - -from mobilus_client.messages.encryptor import MessageEncryptor -from mobilus_client.messages.factory import MessageFactory -from mobilus_client.messages.status import MessageStatus -from mobilus_client.messages.validator import MessageValidator -from mobilus_client.proto import LoginRequest, LoginResponse -from mobilus_client.utils.types import MessageRequest - -logger = logging.getLogger(__name__) - - -class MqttClient(mqtt.Client): - _client_id: bytes - _userdata: dict[str, Any] - - def __init__(self, **kwargs: Any) -> None: # noqa: ANN401 - super().__init__(**kwargs) - self.authenticated_event = threading.Event() - self.completed_event = threading.Event() - self.enable_logger(logger) - - def send_request(self, command: str, **params: str | bytes | int | None) -> None: - if not self.is_connected(): - logger.error("Sending request - %s failed. Client is not connected.", command) - return - - message = MessageFactory.create_message(command, **params) - status = MessageValidator.validate(message) - - if status != MessageStatus.SUCCESS: - logger.error("Command - %s returned an error - %s", command, status.name) - self.disconnect() - return - - if not isinstance(message, LoginRequest): - self._userdata["message_registry"].register_request(message) - - encrypted_message = MessageEncryptor.encrypt( - cast(MessageRequest, message), - self._client_id.decode(), - self._userdata["key_registry"], - ) - - self.publish("module", encrypted_message) - - def on_disconnect(self, _client: mqtt.Client, _userdata: dict[str, Any], reason_code: int) -> None: # type: ignore[override] - logger.info("Disconnected with result code - %s", reason_code) - - def on_connect(self, client: mqtt.Client, _userdata: dict[str, Any], *_args: Any) -> None: # type: ignore[override] # noqa: ANN401 - client.subscribe([ - (self._client_id.decode(), 0), - ("clients", 0), - ]) - - def on_subscribe(self, _client: mqtt.Client, userdata: dict[str, Any], *_args: Any) -> None: # type: ignore[override] # noqa: ANN401 - self.send_request( - "login", - login=userdata["config"].user_login, - password=userdata["config"].user_key, - ) - - def on_message(self, _client: mqtt.Client, userdata: dict[str, Any], mqtt_message: mqtt.MQTTMessage) -> None: # type: ignore[override] - logger.info("Received message on topic - %s", mqtt_message.topic) - - message = MessageEncryptor.decrypt(mqtt_message.payload, userdata["key_registry"]) - logger.info("Decrypted message - %s", type(message).__name__) - - status = MessageValidator.validate(message) - - if status != MessageStatus.SUCCESS: - logger.error("Message - %s returned an error - %s", type(message).__name__, status.name) - self.disconnect() - return - - logger.info("Message - %s validated successfully", type(message).__name__) - - if isinstance(message, LoginResponse): - userdata["key_registry"].register_keys(message) - self.authenticated_event.set() - else: - userdata["message_registry"].register_response(message) - - if userdata["message_registry"].all_responses_received(): - self.completed_event.set() diff --git a/tests/test_app.py b/tests/test_app.py index 66e0caa..d27aaf3 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,11 +1,10 @@ import socket import unittest -from unittest.mock import Mock, patch +from unittest.mock import ANY, Mock, patch from mobilus_client.app import App +from mobilus_client.client import Client from mobilus_client.config import Config -from mobilus_client.mqtt_client import MqttClient -from mobilus_client.registries.key import KeyRegistry from mobilus_client.registries.message import MessageRegistry from tests.factories import ( CallEventsRequestFactory, @@ -24,67 +23,55 @@ def setUp(self) -> None: def test_init(self) -> None: self.assertEqual(self.app.config, self.config) - self.assertIsInstance(self.app.message_registry, MessageRegistry) - self.assertIsInstance(self.app.key_registry, KeyRegistry) - self.assertIsInstance(self.app.client, MqttClient) - @patch.object(MqttClient, "connect", side_effect=socket.gaierror, autospec=True) + @patch.object(Client, "connect_and_authenticate", side_effect=socket.gaierror, autospec=True) def test_call_with_invalid_gateway_host(self, _mock_connect: Mock) -> None: result = self.app.call([("call_events", {})]) self.assertEqual(result, "[]") - @patch.object(MqttClient, "connect", side_effect=TimeoutError, autospec=True) + @patch.object(Client, "connect_and_authenticate", side_effect=TimeoutError, autospec=True) def test_call_with_timeout_gateway_host(self, _mock_connect: Mock) -> None: result = self.app.call([("call_events", {})]) self.assertEqual(result, "[]") - @patch.object(MqttClient, "connect", return_value=Mock(), autospec=True) + @patch.object(Client, "connect_and_authenticate", return_value=Mock(), autospec=True) def test_call_with_empty_commands(self, mock_connect: Mock) -> None: result = self.app.call([]) mock_connect.assert_not_called() self.assertEqual(result, "[]") - @patch.object(MqttClient, "connect", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "loop_start", return_value=Mock(), autospec=True) - def test_call_with_not_authenticated(self, _mock_loop_start: Mock, _mock_connect: Mock) -> None: - self.config.auth_timeout_period = 0.0005 + @patch.object(Client, "connect_and_authenticate", return_value=False, autospec=True) + def test_call_with_not_authenticated(self, _mock_connect: Mock) -> None: result = self.app.call([("call_events", {})]) self.assertEqual(result, "[]") - @patch.object(MqttClient, "connect", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "loop_start", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "disconnect", return_value=Mock(), autospec=True) - def test_call_with_wrong_commands(self, mock_disconnect: Mock, mock_loop_start: Mock, mock_connect: Mock) -> None: - self.app.client.authenticated_event.set() - self.app.client.completed_event.set() + @patch.object(Client, "connect_and_authenticate", return_value=True, autospec=True) + @patch.object(Client, "terminate", return_value=Mock(), autospec=True) + def test_call_with_wrong_commands(self, mock_terminate: Mock, mock_connect: Mock) -> None: result = self.app.call([("wrong", {})]) - mock_connect.assert_called_once_with(self.app.client, "host", 8884, 60) - mock_loop_start.assert_called_once() - mock_disconnect.assert_called_once() + mock_connect.assert_called_once() + mock_terminate.assert_called_once() self.assertEqual(result, "[]") - @patch.object(MqttClient, "connect", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "loop_start", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "send_request", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "disconnect", return_value=Mock(), autospec=True) + @patch.object(Client, "connect_and_authenticate", return_value=True, autospec=True) + @patch.object(Client, "send_request", return_value=Mock(), autospec=True) + @patch.object(Client, "terminate", return_value=Mock(), autospec=True) + @patch.object(MessageRegistry, "get_responses", autospec=True) def test_call_with_commands( - self, mock_disconnect: Mock, mock_send_request: Mock, mock_loop_start: Mock, mock_connect: Mock) -> None: - self.app.client.authenticated_event.set() - self.app.client.completed_event.set() + self, mock_get_responses: Mock, mock_terminate: Mock, mock_send_request: Mock, mock_connect: Mock) -> None: call_events_request = CallEventsRequestFactory( event={"device_id": 1, "event_number": 1, "value": "value", "platform": 1}, ) - self.app.message_registry.register_response(call_events_request) + mock_get_responses.return_value = [call_events_request] result = self.app.call([("call_events", {"device_id": "1", "value": "value"})]) - mock_connect.assert_called_once_with(self.app.client, "host", 8884, 60) - mock_loop_start.assert_called_once() - mock_send_request.assert_called_once_with(self.app.client, "call_events", device_id="1", value="value") - mock_disconnect.assert_called_once() + mock_connect.assert_called_once() + mock_send_request.assert_called_once_with(ANY, "call_events", device_id="1", value="value") + mock_terminate.assert_called_once() self.assertEqual(result, '[{"events": [{"deviceId": "1", "eventNumber": 1, "value": "value", "platform": 1}]}]') diff --git a/tests/test_mqtt_client.py b/tests/test_client.py similarity index 62% rename from tests/test_mqtt_client.py rename to tests/test_client.py index 83dbd8d..319d599 100644 --- a/tests/test_mqtt_client.py +++ b/tests/test_client.py @@ -2,8 +2,10 @@ import unittest from unittest.mock import ANY, Mock, patch +import paho.mqtt.client as mqtt + +from mobilus_client.client import Client from mobilus_client.config import Config -from mobilus_client.mqtt_client import MqttClient from mobilus_client.proto import ( CallEventsRequest, CurrentStateRequest, @@ -20,45 +22,63 @@ from tests.helpers import encrypt_message -class TestMQTTClient(unittest.TestCase): +class TestClient(unittest.TestCase): def setUp(self) -> None: self.client_id = "0123456789ABCDEF" self.config = Config( + auth_timeout_period=0.0005, gateway_host="host", user_login="login", user_password="password", timeout_period=0, ) - self.message_registry = MessageRegistry() self.key_registry = KeyRegistry(self.config.user_key) - self.client = MqttClient( + + self.client = Client( client_id=self.client_id, - transport=self.config.gateway_protocol, - userdata={ - "config": self.config, - "key_registry": self.key_registry, - "message_registry": self.message_registry, - }, + config=self.config, + key_registry=self.key_registry, + message_registry=self.message_registry, ) def test_init(self) -> None: - self.client = MqttClient() - - self.assertIsInstance(self.client, MqttClient) + self.assertEqual(self.client.config, self.config) + self.assertEqual(self.client.key_registry, self.key_registry) + self.assertEqual(self.client.message_registry, self.message_registry) self.assertIsInstance(self.client.authenticated_event, threading.Event) self.assertIsInstance(self.client.completed_event, threading.Event) + self.assertIsInstance(self.client.mqtt_client, mqtt.Client) + + @patch.object(mqtt.Client, "connect", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "loop_start", return_value=Mock(), autospec=True) + def test_connect_and_authenticate_false(self, _mock_loop_start: Mock, mock_connect: Mock) -> None: + result = self.client.connect_and_authenticate() + + mock_connect.assert_called_once_with( + self.client.mqtt_client, self.config.gateway_host, self.config.gateway_port) + self.assertFalse(result) - @patch.object(MqttClient, "is_connected", return_value=False, autospec=True) - @patch.object(MqttClient, "publish", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "connect", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "loop_start", return_value=Mock(), autospec=True) + def test_connect_and_authenticate_true(self, _mock_loop_start: Mock, mock_connect: Mock) -> None: + self.client.authenticated_event.set() + result = self.client.connect_and_authenticate() + + mock_connect.assert_called_once_with( + self.client.mqtt_client, self.config.gateway_host, self.config.gateway_port) + self.assertTrue(result) + + @patch.object(mqtt.Client, "is_connected", return_value=False, autospec=True) + @patch.object(mqtt.Client, "publish", return_value=Mock(), autospec=True) def test_send_request_when_not_connected(self, mock_publish: Mock, _mock_is_connected: Mock) -> None: self.client.send_request("login", login="user", password=self.config.user_key) mock_publish.assert_not_called() - @patch.object(MqttClient, "is_connected", return_value=True, autospec=True) - @patch.object(MqttClient, "disconnect", return_value=Mock(), autospec=True) - @patch.object(MqttClient, "publish", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "is_connected", return_value=True, autospec=True) + @patch.object(mqtt.Client, "disconnect", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "publish", return_value=Mock(), autospec=True) def test_send_request_with_wrong_command( self, mock_publish: Mock, mock_disconnect: Mock, _mock_is_connected: Mock) -> None: self.client.send_request("fake") @@ -67,15 +87,15 @@ def test_send_request_with_wrong_command( mock_publish.assert_not_called() @patch("time.time", return_value=1633036800) - @patch.object(MqttClient, "is_connected", return_value=True, autospec=True) - @patch.object(MqttClient, "publish", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "is_connected", return_value=True, autospec=True) + @patch.object(mqtt.Client, "publish", return_value=Mock(), autospec=True) def test_send_request_with_login_request( self, mock_publish: Mock, _mock_is_connected: Mock, _mock_time: Mock) -> None: self.client.send_request("login", login="user", password=self.config.user_key) self.assertEqual(self.message_registry.get_requests(), []) mock_publish.assert_called_once_with( - self.client, + self.client.mqtt_client, "module", ( b"\x00\x00\x00\r\x01aV*\x00\x01#Eg\x89\xab\xcd\xef\x04\x00\n\x04user\x12 " @@ -84,8 +104,8 @@ def test_send_request_with_login_request( ) @patch("time.time", return_value=1633036800) - @patch.object(MqttClient, "is_connected", return_value=True, autospec=True) - @patch.object(MqttClient, "publish", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "is_connected", return_value=True, autospec=True) + @patch.object(mqtt.Client, "publish", return_value=Mock(), autospec=True) def test_send_request_with_call_events_request( self, mock_publish: Mock, _mock_is_connected: Mock, _mock_time: Mock) -> None: login_response = LoginResponseFactory(private_key=b"test_private_key") @@ -95,14 +115,14 @@ def test_send_request_with_call_events_request( self.assertIsInstance(self.message_registry.get_requests()[0], CallEventsRequest) mock_publish.assert_called_once_with( - self.client, + self.client.mqtt_client, "module", b"\x00\x00\x00\r\raV*\x00\x01#Eg\x89\xab\xcd\xef\x04\x003\xc2\x06\xb5\x9f&\x1b\xfcj2\xd2_\xd9", ) @patch("time.time", return_value=1633036800) - @patch.object(MqttClient, "is_connected", return_value=True, autospec=True) - @patch.object(MqttClient, "publish", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "is_connected", return_value=True, autospec=True) + @patch.object(mqtt.Client, "publish", return_value=Mock(), autospec=True) def test_send_request_with_current_state_request( self, mock_publish: Mock, _mock_is_connected: Mock, _mock_time: Mock) -> None: login_response = LoginResponseFactory(private_key=b"test_private_key") @@ -112,14 +132,14 @@ def test_send_request_with_current_state_request( self.assertIsInstance(self.message_registry.get_requests()[0], CurrentStateRequest) mock_publish.assert_called_once_with( - self.client, + self.client.mqtt_client, "module", b"\x00\x00\x00\r\x1aaV*\x00\x01#Eg\x89\xab\xcd\xef\x04\x00", ) @patch("time.time", return_value=1633036800) - @patch.object(MqttClient, "is_connected", return_value=True, autospec=True) - @patch.object(MqttClient, "publish", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "is_connected", return_value=True, autospec=True) + @patch.object(mqtt.Client, "publish", return_value=Mock(), autospec=True) def test_send_request_with_devices_list_request( self, mock_publish: Mock, _mock_is_connected: Mock, _mock_time: Mock) -> None: login_response = LoginResponseFactory(private_key=b"test_private_key") @@ -129,41 +149,49 @@ def test_send_request_with_devices_list_request( self.assertIsInstance(self.message_registry.get_requests()[0], DevicesListRequest) mock_publish.assert_called_once_with( - self.client, + self.client.mqtt_client, "module", b"\x00\x00\x00\r\x03aV*\x00\x01#Eg\x89\xab\xcd\xef\x04\x00", ) + @patch.object(mqtt.Client, "disconnect", return_value=Mock(), autospec=True) + @patch.object(mqtt.Client, "loop_stop", return_value=Mock(), autospec=True) + def test_terminate(self, mock_loop_stop: Mock, mock_disconnect: Mock) -> None: + self.client.terminate() + + mock_disconnect.assert_called_once() + mock_loop_stop.assert_called_once() + @patch("logging.Logger.info", return_value=Mock(), autospec=True) - def test_on_disconnect(self, mock_info: Mock) -> None: - self.client.on_disconnect(Mock(), {}, 0) + def test_on_disconnect_callback(self, mock_info: Mock) -> None: + self.client.on_disconnect_callback(self.client.mqtt_client, None, 0) mock_info.assert_called_once_with(ANY, "Disconnected with result code - %s", 0) - @patch.object(MqttClient, "subscribe", return_value=Mock(), autospec=True) - def test_on_connect(self, mock_subscribe: Mock) -> None: - self.client.on_connect(self.client, {"config": self.config}, None, 0) + @patch.object(mqtt.Client, "subscribe", return_value=Mock(), autospec=True) + def test_on_connect_callback(self, mock_subscribe: Mock) -> None: + self.client.on_connect_callback(self.client.mqtt_client, None, {}, 0) mock_subscribe.assert_called_once_with( - self.client, + self.client.mqtt_client, [ (self.client_id, 0), ("clients", 0), ], ) - @patch.object(MqttClient, "send_request", return_value=Mock(), autospec=True) - def test_on_subscribe(self, mock_send_request: Mock) -> None: - self.client.on_subscribe(self.client, {"config": self.config}, 0, None) + @patch.object(Client, "send_request", return_value=Mock(), autospec=True) + def test_on_subscribe_callback(self, mock_send_request: Mock) -> None: + self.client.on_subscribe_callback(self.client.mqtt_client, None, 0, (0,)) mock_send_request.assert_called_once_with( self.client, "login", login=self.config.user_login, password=self.config.user_key) - @patch.object(MqttClient, "disconnect", return_value=Mock(), autospec=True) - def test_on_message_invalid(self, mock_disconnect: Mock) -> None: + @patch.object(mqtt.Client, "disconnect", return_value=Mock(), autospec=True) + def test_on_message_callback_invalid(self, mock_disconnect: Mock) -> None: message = Mock(payload=b"invalid") - self.client.on_message(self.client, {"config": self.config, "key_registry": self.key_registry}, message) + self.client.on_message_callback(self.client.mqtt_client, None, message) mock_disconnect.assert_called_once() @@ -172,7 +200,7 @@ def test_on_message_login_response(self) -> None: encrypted_message = encrypt_message(login_response, self.config.user_key) message = Mock(payload=encrypted_message) - self.client.on_message(self.client, {"config": self.config, "key_registry": self.key_registry}, message) + self.client.on_message_callback(self.client.mqtt_client, None, message) self.assertTrue(self.client.authenticated_event.is_set()) self.assertEqual(self.key_registry.get_keys(), { @@ -190,12 +218,7 @@ def test_on_message_call_events_request_all_completed(self) -> None: encrypted_message = encrypt_message(call_events_request, login_response.public_key) message = Mock(payload=encrypted_message) - self.client.on_message( - self.client, { - "config": self.config, - "key_registry": self.key_registry, - "message_registry": self.message_registry, - }, message) + self.client.on_message_callback(self.client.mqtt_client, None, message) self.assertEqual(self.message_registry.get_responses(), [call_events_request]) self.assertTrue(self.client.completed_event.is_set()) @@ -207,12 +230,7 @@ def test_on_message_current_state_response_not_all_completed(self) -> None: encrypted_message = encrypt_message(current_state_response, login_response.private_key) message = Mock(payload=encrypted_message) - self.client.on_message( - self.client, { - "config": self.config, - "key_registry": self.key_registry, - "message_registry": self.message_registry, - }, message) + self.client.on_message_callback(self.client.mqtt_client, None, message) self.assertEqual(self.message_registry.get_responses(), [current_state_response]) self.assertFalse(self.client.completed_event.is_set()) @@ -224,11 +242,6 @@ def test_on_message_devices_list_response_not_all_completed(self) -> None: encrypted_message = encrypt_message(devices_list_response, login_response.private_key) message = Mock(payload=encrypted_message) - self.client.on_message( - self.client, { - "config": self.config, - "key_registry": self.key_registry, - "message_registry": self.message_registry, - }, message) + self.client.on_message_callback(self.client.mqtt_client, None, message) self.assertEqual(self.message_registry.get_responses(), [devices_list_response]) self.assertFalse(self.client.completed_event.is_set())