Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not record messages in registry that are not expected. #32

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions mobilus_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,16 @@ def on_subscribe_callback(self, _client: mqtt.Client, _userdata: None, _mid: int
def on_message_callback(self, _client: mqtt.Client, _userdata: None, mqtt_message: mqtt.MQTTMessage) -> None:
logger.info("Received message on topic - %s", mqtt_message.topic)

# Ignore messages from the shared topic sent by other clients until the client is authenticated.
if mqtt_message.topic == self.shared_topic and not self.authenticated_event.is_set():
logger.info("Client is not authenticated yet. Ignoring message.")
return

message = MessageEncryptor.decrypt(mqtt_message.payload, self.key_registry)
logger.info("Decrypted message - %s", type(message).__name__)

if message is None:
logger.info("Failed to decrypt message, ignoring")
return

status = MessageValidator.validate(message)

if status != MessageStatus.SUCCESS or message is None:
if status != MessageStatus.SUCCESS:
logger.error("Message - %s returned an error - %s", type(message).__name__, status.name)
self.terminate()
return
Expand All @@ -116,7 +115,7 @@ def on_message_callback(self, _client: mqtt.Client, _userdata: None, mqtt_messag
if isinstance(message, LoginResponse):
self.key_registry.register_keys(message)
self.authenticated_event.set()
else:
elif self.message_registry.is_expected_response(message):
self.message_registry.register_response(message)

if self.message_registry.all_responses_received():
Expand Down
3 changes: 3 additions & 0 deletions mobilus_client/messages/encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def decrypt(encrypted_message: bytes, key_registry: KeyRegistry) -> MessageRespo
# Choose proper decryption key
key = key_registry.get_decryption_key(message_klass)

if not key:
return None

# Decrypt body
iv = create_iv(timestamp)
body = decrypt_body(key, iv, encrypted_body)
Expand Down
2 changes: 2 additions & 0 deletions mobilus_client/registries/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class KeyRegistry:
def __init__(self, user_key: bytes) -> None:
self._registry = {
"user_key": user_key,
"private_key": b"",
"public_key": b"",
}

def get_keys(self) -> dict[str, bytes]:
Expand Down
6 changes: 6 additions & 0 deletions mobilus_client/registries/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def get_requests(self) -> list[MessageRequest]:
def get_responses(self) -> list[MessageResponse]:
return self._responses

def is_expected_response(self, message_response: MessageResponse) -> bool:
return any(
isinstance(message_response, self.MESSAGE_MAP[type(request)])
for request in self.get_requests()
)

def all_responses_received(self) -> bool:
expected_responses = Counter(self.MESSAGE_MAP[type(request)] for request in self.get_requests())
actual_responses = Counter(type(response) for response in self.get_responses())
Expand Down
9 changes: 9 additions & 0 deletions tests/messages/test_encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,12 @@ def test_decrypt_invalid_message_with_invalid_key(self) -> None:
sys.stdout.write(f"result:{type(result)}:")

self.assertIsNone(result)

def test_decrypt_invalid_message_with_no_public_key(self) -> None:
self.key_registry = KeyRegistry(self.user_key)
message = CallEventsRequestFactory()
encrypted_message = encrypt_message(message, b"test_invalid_key")

result = MessageEncryptor.decrypt(encrypted_message, self.key_registry)

self.assertIsNone(result)
35 changes: 22 additions & 13 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ def test_on_subscribe_callback(self, mock_send_request: Mock) -> None:
mock_send_request.assert_called_once_with(
self.client, "login", login=self.config.user_login, password=self.config.user_key)

@patch.object(mqtt.Client, "disconnect", return_value=Mock(), autospec=True)
def test_on_message_callback_invalid(self, mock_disconnect: Mock) -> None:
@patch("logging.Logger.info", return_value=Mock(), autospec=True)
def test_on_message_callback_invalid(self, mock_info: Mock) -> None:
message = Mock(payload=b"invalid")

self.client.on_message_callback(self.client.mqtt_client, None, message)

mock_disconnect.assert_called_once()
mock_info.assert_called_with(ANY, "Failed to decrypt message, ignoring")

def test_on_message_login_response(self) -> None:
login_response = LoginResponseFactory()
Expand All @@ -209,6 +209,17 @@ def test_on_message_login_response(self) -> None:
"public_key": login_response.public_key,
})

@patch.object(Client, "terminate", return_value=Mock(), autospec=True)
def test_on_message_login_response_unauthenticated(self, mock_terminate: Mock) -> None:
login_response = LoginResponseFactory(failed=True)
encrypted_message = encrypt_message(login_response, self.config.user_key)
message = Mock(payload=encrypted_message)

self.client.on_message_callback(self.client.mqtt_client, None, message)

self.assertFalse(self.client.authenticated_event.is_set())
mock_terminate.assert_called_once()

def test_on_message_call_events_request_all_completed(self) -> None:
login_response = LoginResponseFactory(public_key=b"test_public_key_")
self.key_registry.register_keys(login_response)
Expand All @@ -222,21 +233,27 @@ def test_on_message_call_events_request_all_completed(self) -> None:
self.assertEqual(self.message_registry.get_responses(), [call_events_request])
self.assertTrue(self.client.completed_event.is_set())

def test_on_message_current_state_response_not_all_completed(self) -> None:
def test_on_message_current_state_response_all_completed(self) -> None:
login_response = LoginResponseFactory(private_key=b"test_private_key")
self.key_registry.register_keys(login_response)
current_state_request = CurrentStateRequest()
self.message_registry.register_request(current_state_request)

current_state_response = CurrentStateResponseFactory()
encrypted_message = encrypt_message(current_state_response, login_response.private_key)
message = Mock(payload=encrypted_message)

self.client.on_message_callback(self.client.mqtt_client, None, message)
self.assertEqual(self.message_registry.get_responses(), [current_state_response])
self.assertFalse(self.client.completed_event.is_set())
self.assertTrue(self.client.completed_event.is_set())

def test_on_message_devices_list_response_not_all_completed(self) -> None:
login_response = LoginResponseFactory(private_key=b"test_private_key")
self.key_registry.register_keys(login_response)
device_list_request = DevicesListRequest()
current_state_request = CurrentStateRequest()
self.message_registry.register_request(device_list_request)
self.message_registry.register_request(current_state_request)

devices_list_response = DevicesListResponseFactory()
encrypted_message = encrypt_message(devices_list_response, login_response.private_key)
Expand All @@ -245,11 +262,3 @@ def test_on_message_devices_list_response_not_all_completed(self) -> None:
self.client.on_message_callback(self.client.mqtt_client, None, message)
self.assertEqual(self.message_registry.get_responses(), [devices_list_response])
self.assertFalse(self.client.completed_event.is_set())

def test_on_message_shared_topic_when_not_authenticated(self) -> None:
login_response = LoginResponseFactory(private_key=b"test_private_key")
devices_list_response = DevicesListResponseFactory()
encrypted_message = encrypt_message(devices_list_response, login_response.private_key)
message = Mock(payload=encrypted_message, topic="clients")

self.client.on_message_callback(self.client.mqtt_client, None, message)