Skip to content

Commit

Permalink
Simplify mocking of mqqt client
Browse files Browse the repository at this point in the history
This way we do not have scope imports from loaders.mqtt to the test body.
  • Loading branch information
twiggler committed Sep 18, 2024
1 parent 97349d8 commit 0c079ca
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions tests/loaders/test_mqtt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import argparse
import sys
import time
from dataclasses import dataclass
from struct import pack
Expand All @@ -11,6 +10,7 @@
import pytest

from dissect.target import Target
from dissect.target.loaders.mqtt import Broker, MQTTConnection, case, host_name


class MQTTMock(MagicMock):
Expand Down Expand Up @@ -69,21 +69,10 @@ def read(self, *args) -> MockSeekMessage | None:


@pytest.fixture
def mock_paho(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]:
with monkeypatch.context() as m:
mock_paho = MagicMock()
m.setitem(sys.modules, "paho", mock_paho)
m.setitem(sys.modules, "paho.mqtt", mock_paho.mqtt)
m.setitem(sys.modules, "paho.mqtt.client", mock_paho.mqtt.client)
def mock_client(monkeypatch: pytest.MonkeyPatch) -> None:
import paho.mqtt.client as mqtt

yield mock_paho


@pytest.fixture
def mock_client(mock_paho: MagicMock) -> Iterator[MagicMock]:
mock_client = MQTTMock()
mock_paho.mqtt.client.Client.return_value = mock_client
yield mock_client
monkeypatch.setattr(mqtt, "Client", MQTTMock)


@pytest.fixture
Expand All @@ -104,7 +93,7 @@ def mock_broker() -> Iterator[MockBroker]:
@patch.object(time, "sleep") # improve speed during test, no need to wait for peers
def test_remote_loader_stream(
time: MagicMock,
mock_client: MagicMock,
mock_client: None,
alias: str,
hosts: list[str],
disks: list[int],
Expand All @@ -113,8 +102,6 @@ def test_remote_loader_stream(
read: int,
expected: bytes,
) -> None:
from dissect.target.loaders.mqtt import Broker

broker = Broker("0.0.0.0", "1884", "key", "crt", "ca", "case1", "user", "pass")
broker.connect()
broker.mqtt_client.fill_disks(disks)
Expand All @@ -134,8 +121,6 @@ def test_remote_loader_stream(


def test_mqtt_loader_prefetch(mock_broker: MockBroker) -> None:
from dissect.target.loaders.mqtt import MQTTConnection

connection = MQTTConnection(mock_broker, "")
connection.prefetch_factor_inc = 10
assert connection.factor == 1
Expand Down Expand Up @@ -174,6 +159,7 @@ def generate_longest_valid_hostname():
("example-label.com", True),
("example..com", False),
(generate_longest_valid_hostname(), True),
(generate_longest_valid_hostname() + 'a', False),
],
ids=[
"valid_domain",
Expand All @@ -188,11 +174,10 @@ def generate_longest_valid_hostname():
"valid_domain_with_hyphen",
"invalid_empty_label",
"valid_max_length",
"too_long"
],
)
def test_host_name_parser(hostname: str, is_valid_hostname: bool) -> None:
from dissect.target.loaders.mqtt import host_name

assert host_name(hostname) == is_valid_hostname


Expand Down Expand Up @@ -220,8 +205,6 @@ def test_host_name_parser(hostname: str, is_valid_hostname: bool) -> None:
],
)
def test_case(case_name, parse_result: str | pytest.RaisesContext[argparse.ArgumentTypeError]) -> None:
from dissect.target.loaders.mqtt import case

if isinstance(parse_result, str):
assert case(case_name) == parse_result
else:
Expand Down

0 comments on commit 0c079ca

Please sign in to comment.