Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add prefetch to MQTT loader #659

Merged
merged 12 commits into from
Apr 15, 2024
57 changes: 41 additions & 16 deletions dissect/target/loaders/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def _read(self, offset: int, length: int, optimization_strategy: int = 0) -> byt
class MQTTConnection:
broker = None
host = None
prev = -1
factor = 1
prefetch_factor_inc = 10

def __init__(self, broker: Broker, host: str):
self.broker = broker
Expand Down Expand Up @@ -95,20 +98,32 @@ def info(self) -> list[MQTTStream]:

def read(self, disk_id: int, offset: int, length: int, optimization_strategy: int) -> bytes:
message = None
self.broker.seek(self.host, disk_id, offset, length, optimization_strategy)

message = self.broker.read(self.host, disk_id, offset, length)
if message:
return message.data

if self.prev == offset - (length * self.factor):
if self.factor < 500:
self.factor += self.prefetch_factor_inc
else:
self.factor = 1

self.prev = offset
flength = length * self.factor
self.broker.factor = self.factor
self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy)
attempts = 0
while True:
message = self.broker.read(self.host, disk_id, offset, length)
# don't waste time with sleep if we have a response
if message:
if message := self.broker.read(self.host, disk_id, offset, length):
# don't waste time with sleep if we have a response
break

attempts += 1
time.sleep(0.01)
if attempts > 100:
time.sleep(0.1)
if attempts > 300:
# message might not have reached the agent, resend...
self.broker.seek(self.host, disk_id, offset, length, optimization_strategy)
self.broker.seek(self.host, disk_id, offset, flength, optimization_strategy)
attempts = 0

return message.data
Expand All @@ -127,6 +142,7 @@ class Broker:
diskinfo = {}
index = {}
topo = {}
factor = 1

def __init__(self, broker: Broker, port: str, key: str, crt: str, ca: str, case: str, **kwargs):
self.broker_host = broker
Expand All @@ -137,10 +153,13 @@ def __init__(self, broker: Broker, port: str, key: str, crt: str, ca: str, case:
self.case = case
self.command = kwargs.get("command", None)

def clear_cache(self) -> None:
    """Forget every indexed read response held by this instance."""
    # Rebind to a new dict rather than clearing in place, so whatever dict
    # ``index`` referred to before (such as the class-level default) is
    # left untouched.
    self.index = {}

@suppress
def read(self, host: str, disk_id: int, seek_address: int, read_length: int) -> SeekMessage:
key = f"{host}-{disk_id}-{seek_address}-{read_length}"
return self.index.pop(key)
return self.index.get(key)

@suppress
def disk(self, host: str) -> DiskMessage:
Expand All @@ -165,14 +184,15 @@ def _on_read(self, hostname: str, tokens: list[str], payload: bytes) -> None:
disk_id = tokens[3]
seek_address = int(tokens[4], 16)
read_length = int(tokens[5], 16)
msg = SeekMessage(data=payload)

key = f"{hostname}-{disk_id}-{seek_address}-{read_length}"
for i in range(self.factor):
sublength = int(read_length / self.factor)
start = i * sublength
key = f"{hostname}-{disk_id}-{seek_address+start}-{sublength}"
if key in self.index:
continue

if key in self.index:
return

self.index[key] = msg
self.index[key] = SeekMessage(data=payload[start : start + sublength])

def _on_id(self, hostname: str, payload: bytes) -> None:
key = hostname
Expand Down Expand Up @@ -204,9 +224,14 @@ def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.client.MQTTM
elif response == "ID":
self._on_id(hostname, msg.payload)

def seek(self, host: str, disk_id: int, offset: int, length: int, optimization_strategy: int) -> None:
def seek(self, host: str, disk_id: int, offset: int, flength: int, optimization_strategy: int) -> None:
length = int(flength / self.factor)
key = f"{host}-{disk_id}-{offset}-{length}"
if key in self.index:
return

self.mqtt_client.publish(
f"{self.case}/{host}/SEEK/{disk_id}/{hex(offset)}/{hex(length)}", pack("<I", optimization_strategy)
f"{self.case}/{host}/SEEK/{disk_id}/{hex(offset)}/{hex(flength)}", pack("<I", optimization_strategy)
)

def info(self, host: str) -> None:
Expand Down
3 changes: 1 addition & 2 deletions dissect/target/tools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def main():
collected_plugins = {}

if targets:
for target in targets:
plugin_target = Target.open(target)
for plugin_target in Target.open_all(targets, args.children):
if isinstance(plugin_target._loader, ProxyLoader):
parser.error("can't list compatible plugins for remote targets.")
funcs, _ = find_plugin_functions(plugin_target, args.list, compatibility=True, show_hidden=True)
Expand Down
44 changes: 44 additions & 0 deletions tests/loaders/test_mqtt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import sys
import time
from dataclasses import dataclass
from struct import pack
from typing import Iterator
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -44,6 +47,24 @@ def publish(self, topic: str, *args) -> None:
self.on_message(self, None, response)


@dataclass
class MockSeekMessage:
    """Test double for a seek response; it carries nothing but the payload."""

    # Raw bytes handed back to the reader; empty unless the test supplies some.
    data: bytes = b""


class MockBroker(MagicMock):
    """Broker double whose ``read`` only produces data right after a ``seek``.

    Any other attribute or method is served by ``MagicMock``.
    """

    # Raised by ``seek`` and lowered again by the next successful ``read``.
    _seek = False

    def seek(self, *args) -> None:
        """Note that a seek was requested; the arguments themselves are ignored."""
        self._seek = True

    def read(self, *args) -> MockSeekMessage | None:
        """Return a fixed message once per preceding ``seek``, otherwise ``None``."""
        if not self._seek:
            return None

        # Consume the pending seek so a second read without a new seek misses.
        self._seek = False
        return MockSeekMessage(data=b"010101")


@pytest.fixture
def mock_paho(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]:
with monkeypatch.context() as m:
Expand All @@ -62,6 +83,11 @@ def mock_client(mock_paho: MagicMock) -> Iterator[MagicMock]:
yield mock_client


@pytest.fixture
def mock_broker() -> Iterator[MockBroker]:
    """Hand a newly constructed ``MockBroker`` to the requesting test."""
    broker = MockBroker()
    yield broker


@pytest.mark.parametrize(
"alias, hosts, disks, disk, seek, read, expected",
[
Expand Down Expand Up @@ -102,3 +128,21 @@ def test_remote_loader_stream(
target.disks[disk].seek(seek)
data = target.disks[disk].read(read)
assert data == expected


def test_mqtt_loader_prefetch(mock_broker: MockBroker) -> None:
    """Sequential reads should grow the prefetch factor and track the last offset."""
    from dissect.target.loaders.mqtt import MQTTConnection

    connection = MQTTConnection(mock_broker, "")
    connection.prefetch_factor_inc = 10
    step = connection.prefetch_factor_inc

    # Nothing has been read yet: no prefetching and no previous offset.
    assert connection.factor == 1
    assert connection.prev == -1

    # Each entry is (offset to read, factor expected afterwards). Every read
    # asks for 100 bytes. An offset counts as sequential when it equals the
    # previous offset plus 100 * the current factor, hence 100 and then 1200.
    reads = [
        (0, 1),
        (100, step + 1),
        (1200, (step * 2) + 1),
    ]
    for offset, expected_factor in reads:
        connection.read(1, offset, 100, 0)
        assert connection.factor == expected_factor
        assert connection.prev == offset