diff --git a/wgkex/worker/mqtt.py b/wgkex/worker/mqtt.py index 24d9b7e..53a06be 100644 --- a/wgkex/worker/mqtt.py +++ b/wgkex/worker/mqtt.py @@ -5,7 +5,7 @@ import re import socket import threading -from typing import Optional, Dict, Any, Union +from typing import Optional, Any import paho.mqtt.client as mqtt diff --git a/wgkex/worker/mqtt_test.py b/wgkex/worker/mqtt_test.py index c960142..bea436f 100644 --- a/wgkex/worker/mqtt_test.py +++ b/wgkex/worker/mqtt_test.py @@ -1,8 +1,13 @@ """Unit tests for mqtt.py""" +import socket import threading import unittest +from time import sleep + import mock +import paho.mqtt.client +from wgkex.common.mqtt import TOPIC_CONNECTED_PEERS from wgkex.worker import mqtt @@ -88,6 +93,43 @@ def test_on_message_wireguard_fails_no_domain(self, config_mock, link_mock): with self.assertRaises(ValueError): mqtt.on_message_wireguard(None, None, mqtt_msg) + @mock.patch.object(mqtt, "get_config") + @mock.patch.object(mqtt, "get_connected_peers_count") + def test_publish_metrics_loop_success(self, conn_peers_mock, config_mock): + config_mock.return_value = _get_config_mock() + conn_peers_mock.return_value = 20 + mqtt_client = mock.MagicMock(spec=paho.mqtt.client.Client) + + ee = threading.Event() + thread = threading.Thread( + target=mqtt.publish_metrics_loop, + args=(ee, mqtt_client, "_ffmuc_domain.one"), + ) + thread.start() + + i = 0 + while i < 20 and not mqtt_client.publish.called: + i += 1 + sleep(0.1) + + conn_peers_mock.assert_called_with("wg-domain.one") + mqtt_client.publish.assert_called_with( + TOPIC_CONNECTED_PEERS.format( + domain="_ffmuc_domain.one", worker=socket.gethostname() + ), + 20, + retain=True, + ) + + ee.set() + + i = 0 + while i < 20 and thread.is_alive(): + i += 1 + sleep(0.1) + + self.assertFalse(thread.is_alive()) + if __name__ == "__main__": unittest.main()