diff --git a/mobilus_client/client.py b/mobilus_client/client.py index 0343c4f..039a314 100644 --- a/mobilus_client/client.py +++ b/mobilus_client/client.py @@ -96,17 +96,16 @@ def on_subscribe_callback(self, _client: mqtt.Client, _userdata: None, _mid: int def on_message_callback(self, _client: mqtt.Client, _userdata: None, mqtt_message: mqtt.MQTTMessage) -> None: logger.info("Received message on topic - %s", mqtt_message.topic) - # Ignore messages from the shared topic sent by other clients until the client is authenticated. - if mqtt_message.topic == self.shared_topic and not self.authenticated_event.is_set(): - logger.info("Client is not authenticated yet. Ignoring message.") - return - message = MessageEncryptor.decrypt(mqtt_message.payload, self.key_registry) logger.info("Decrypted message - %s", type(message).__name__) + if message is None: + logger.info("Failed to decrypt message, ignoring") + return + status = MessageValidator.validate(message) - if status != MessageStatus.SUCCESS or message is None: + if status != MessageStatus.SUCCESS: logger.error("Message - %s returned an error - %s", type(message).__name__, status.name) self.terminate() return @@ -116,7 +115,7 @@ def on_message_callback(self, _client: mqtt.Client, _userdata: None, mqtt_messag if isinstance(message, LoginResponse): self.key_registry.register_keys(message) self.authenticated_event.set() - else: + elif self.message_registry.is_expected_response(message): self.message_registry.register_response(message) if self.message_registry.all_responses_received(): diff --git a/mobilus_client/messages/encryptor.py b/mobilus_client/messages/encryptor.py index 96c4647..84c8428 100644 --- a/mobilus_client/messages/encryptor.py +++ b/mobilus_client/messages/encryptor.py @@ -86,6 +86,9 @@ def decrypt(encrypted_message: bytes, key_registry: KeyRegistry) -> MessageRespo # Choose proper decryption key key = key_registry.get_decryption_key(message_klass) + if not key: + return None + # Decrypt body iv = create_iv(timestamp) body = decrypt_body(key, iv, encrypted_body) diff --git a/mobilus_client/registries/key.py b/mobilus_client/registries/key.py index fe04543..cb917ed 100644 --- a/mobilus_client/registries/key.py +++ b/mobilus_client/registries/key.py @@ -12,6 +12,8 @@ class KeyRegistry: def __init__(self, user_key: bytes) -> None: self._registry = { "user_key": user_key, + "private_key": b"", + "public_key": b"", } def get_keys(self) -> dict[str, bytes]: diff --git a/mobilus_client/registries/message.py b/mobilus_client/registries/message.py index b5b7828..bebc179 100644 --- a/mobilus_client/registries/message.py +++ b/mobilus_client/registries/message.py @@ -37,6 +37,12 @@ def get_requests(self) -> list[MessageRequest]: def get_responses(self) -> list[MessageResponse]: return self._responses + def is_expected_response(self, message_response: MessageResponse) -> bool: + return any( + isinstance(message_response, self.MESSAGE_MAP[type(request)]) + for request in self.get_requests() + ) + def all_responses_received(self) -> bool: expected_responses = Counter(self.MESSAGE_MAP[type(request)] for request in self.get_requests()) actual_responses = Counter(type(response) for response in self.get_responses()) diff --git a/tests/messages/test_encryptor.py b/tests/messages/test_encryptor.py index 4d7bd02..8523054 100644 --- a/tests/messages/test_encryptor.py +++ b/tests/messages/test_encryptor.py @@ -167,3 +167,12 @@ def test_decrypt_invalid_message_with_invalid_key(self) -> None: sys.stdout.write(f"result:{type(result)}:") self.assertIsNone(result) + + def test_decrypt_invalid_message_with_no_public_key(self) -> None: + self.key_registry = KeyRegistry(self.user_key) + message = CallEventsRequestFactory() + encrypted_message = encrypt_message(message, b"test_invalid_key") + + result = MessageEncryptor.decrypt(encrypted_message, self.key_registry) + + self.assertIsNone(result) diff --git a/tests/test_client.py b/tests/test_client.py index 2d8e299..85319f9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -187,13 +187,13 @@ def test_on_subscribe_callback(self, mock_send_request: Mock) -> None: mock_send_request.assert_called_once_with( self.client, "login", login=self.config.user_login, password=self.config.user_key) - @patch.object(mqtt.Client, "disconnect", return_value=Mock(), autospec=True) - def test_on_message_callback_invalid(self, mock_disconnect: Mock) -> None: + @patch("logging.Logger.info", return_value=Mock(), autospec=True) + def test_on_message_callback_invalid(self, mock_info: Mock) -> None: message = Mock(payload=b"invalid") self.client.on_message_callback(self.client.mqtt_client, None, message) - mock_disconnect.assert_called_once() + mock_info.assert_called_with(ANY, "Failed to decrypt message, ignoring") def test_on_message_login_response(self) -> None: login_response = LoginResponseFactory() @@ -209,6 +209,17 @@ def test_on_message_login_response(self) -> None: "public_key": login_response.public_key, }) + @patch.object(Client, "terminate", return_value=Mock(), autospec=True) + def test_on_message_login_response_unauthenticated(self, mock_terminate: Mock) -> None: + login_response = LoginResponseFactory(failed=True) + encrypted_message = encrypt_message(login_response, self.config.user_key) + message = Mock(payload=encrypted_message) + + self.client.on_message_callback(self.client.mqtt_client, None, message) + + self.assertFalse(self.client.authenticated_event.is_set()) + mock_terminate.assert_called_once() + def test_on_message_call_events_request_all_completed(self) -> None: login_response = LoginResponseFactory(public_key=b"test_public_key_") self.key_registry.register_keys(login_response) @@ -222,9 +233,11 @@ def test_on_message_call_events_request_all_completed(self) -> None: self.assertEqual(self.message_registry.get_responses(), [call_events_request]) self.assertTrue(self.client.completed_event.is_set()) - def test_on_message_current_state_response_not_all_completed(self) -> None: + def test_on_message_current_state_response_all_completed(self) -> None: login_response = LoginResponseFactory(private_key=b"test_private_key") self.key_registry.register_keys(login_response) + current_state_request = CurrentStateRequest() + self.message_registry.register_request(current_state_request) current_state_response = CurrentStateResponseFactory() encrypted_message = encrypt_message(current_state_response, login_response.private_key) @@ -232,11 +245,15 @@ def test_on_message_current_state_response_not_all_completed(self) -> None: self.client.on_message_callback(self.client.mqtt_client, None, message) self.assertEqual(self.message_registry.get_responses(), [current_state_response]) - self.assertFalse(self.client.completed_event.is_set()) + self.assertTrue(self.client.completed_event.is_set()) def test_on_message_devices_list_response_not_all_completed(self) -> None: login_response = LoginResponseFactory(private_key=b"test_private_key") self.key_registry.register_keys(login_response) + device_list_request = DevicesListRequest() + current_state_request = CurrentStateRequest() + self.message_registry.register_request(device_list_request) + self.message_registry.register_request(current_state_request) devices_list_response = DevicesListResponseFactory() encrypted_message = encrypt_message(devices_list_response, login_response.private_key) @@ -245,11 +262,3 @@ def test_on_message_devices_list_response_not_all_completed(self) -> None: self.client.on_message_callback(self.client.mqtt_client, None, message) self.assertEqual(self.message_registry.get_responses(), [devices_list_response]) self.assertFalse(self.client.completed_event.is_set()) - - def test_on_message_shared_topic_when_not_authenticated(self) -> None: - login_response = LoginResponseFactory(private_key=b"test_private_key") - devices_list_response = DevicesListResponseFactory() - encrypted_message = encrypt_message(devices_list_response, login_response.private_key) - message = Mock(payload=encrypted_message, topic="clients") - - self.client.on_message_callback(self.client.mqtt_client, None, message)